diff --git a/Cargo.lock b/Cargo.lock index 10ee1ea60..5aadd4303 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1744,6 +1744,7 @@ dependencies = [ "testresult", "thiserror 2.0.17", "tokio", + "tokio-stream", "tokio-tungstenite 0.27.0", "toml", "tower-http", diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index f7a044c94..0bf03ed0f 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -77,6 +77,7 @@ opentelemetry_sdk = { optional = true, version = "0.31", features = ["rt-tokio"] # internal deps freenet-stdlib = { features = ["net"], workspace = true } console-subscriber = { version = "0.4.1", optional = true } +tokio-stream = "0.1.17" [target.'cfg(windows)'.dependencies] winapi = { version = "0.3", features = ["sysinfoapi"] } diff --git a/crates/core/src/node/network_bridge/handshake.rs b/crates/core/src/node/network_bridge/handshake.rs index e930346fd..b127244d8 100644 --- a/crates/core/src/node/network_bridge/handshake.rs +++ b/crates/core/src/node/network_bridge/handshake.rs @@ -6,9 +6,9 @@ use std::{ sync::{atomic::AtomicBool, Arc}, }; use tokio::time::{timeout, Duration}; -use tracing::{instrument, Instrument}; +use tracing::Instrument; -use futures::{future::BoxFuture, stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt}; +use futures::{future::BoxFuture, stream::FuturesUnordered, Future, FutureExt, TryFutureExt}; use tokio::sync::mpsc::{self}; use crate::{ @@ -55,8 +55,6 @@ pub(super) enum HandshakeError { #[derive(Debug)] pub(super) enum Event { - // todo: instead of returning InboundJoinReq which is an internal event - // return a proper well formed ConnectOp and any other types needed (PeerConnection etc.) /// An inbound connection to a peer was successfully established at a gateway. InboundConnection { id: Transaction, @@ -99,7 +97,9 @@ pub(super) enum Event { }, } -#[allow(clippy::large_enum_variant)] +/// NOTE: This enum is no longer used but kept for reference during transition. +/// The Stream implementation infers the forward result from forward_conn's ConnectState. +#[allow(dead_code, clippy::large_enum_variant)] enum ForwardResult { Forward(PeerId, NetMessage, ConnectivityInfo), DirectlyAccepted(ConnectivityInfo), @@ -269,487 +269,6 @@ impl HandshakeHandler { ) } - /// Processes events related to connection establishment and management. - #[instrument(skip(self))] - pub async fn wait_for_events(&mut self) -> Result { - loop { - tracing::trace!( - "wait_for_events loop iteration - unconfirmed: {}, ongoing_outbound: {}", - self.unconfirmed_inbound_connections.len(), - self.ongoing_outbound_connections.len() - ); - tokio::select! 
{ - // Handle new inbound connections - new_conn = self.inbound_conn_handler.next_connection() => { - let Some(conn) = new_conn else { - return Err(HandshakeError::ChannelClosed); - }; - tracing::debug!(from=%conn.remote_addr(), "New inbound connection"); - self.track_inbound_connection(conn); - } - // Process outbound connection attempts - outbound_conn = self.ongoing_outbound_connections.next(), if !self.ongoing_outbound_connections.is_empty() => { - let r = match outbound_conn { - Some(Ok(InternalEvent::OutboundConnEstablished(peer_id, connection))) => { - tracing::info!(at=?connection.my_address(), from=%connection.remote_addr(), "Outbound connection successful"); - Ok(Event::OutboundConnectionSuccessful { peer_id, connection }) - } - Some(Ok(InternalEvent::OutboundGwConnEstablished(id, connection))) => { - tracing::info!(at=?connection.my_address(), from=%connection.remote_addr(), "Outbound gateway connection successful"); - if let Some(addr) = connection.my_address() { - tracing::debug!(%addr, "Attempting setting own peer key"); - self.connection_manager.try_set_peer_key(addr); - - // For non-gateway peers: mark as ready to accept client operations - if let Some(ref peer_ready) = self.peer_ready { - peer_ready.store(true, std::sync::atomic::Ordering::SeqCst); - tracing::info!("Peer initialization complete: peer_ready set to true, client operations now enabled"); - } - - if self.this_location.is_none() { - // in the case trust locations is set to true, this peer already had its location set - self.connection_manager.update_location(Some(Location::from_address(&addr))); - } - } - tracing::debug!(at=?connection.my_address(), from=%connection.remote_addr(), "Outbound connection to gw successful"); - self.wait_for_gw_confirmation(id, connection, Ring::DEFAULT_MAX_HOPS_TO_LIVE).await?; - continue; - } - Some(Ok(InternalEvent::FinishedOutboundConnProcess(tracker))) => { - self.connecting.remove(&tracker.gw_peer.peer.addr); - // at this point we are done checking all the accepts inbound from a transient gw conn - tracing::debug!(at=?tracker.gw_conn.my_address(), gw=%tracker.gw_conn.remote_addr(), "Done checking, connection not accepted by gw, dropping connection"); - Ok(Event::OutboundGatewayConnectionRejected { peer_id: tracker.gw_peer.peer }) - } - Some(Ok(InternalEvent::OutboundGwConnConfirmed(tracker))) => { - tracing::debug!(at=?tracker.gw_conn.my_address(), from=%tracker.gw_conn.remote_addr(), "Outbound connection to gw confirmed"); - self.connected.insert(tracker.gw_conn.remote_addr()); - self.connecting.remove(&tracker.gw_conn.remote_addr()); - return Ok(Event::OutboundGatewayConnectionSuccessful { - peer_id: tracker.gw_peer.peer, - connection: tracker.gw_conn, - remaining_checks: tracker.remaining_checks, - }); - } - Some(Ok(InternalEvent::NextCheck(tracker))) => { - self.ongoing_outbound_connections.push( - check_remaining_hops(tracker).boxed() - ); - continue; - } - Some(Ok(InternalEvent::RemoteConnectionAttempt { remote, tracker })) => { - // this shouldn't happen as the tx would exit this module - // see: OutboundGwConnConfirmed - debug_assert!(!tracker.gw_accepted); - tracing::debug!( - at=?tracker.gw_conn.my_address(), - gw=%tracker.gw_conn.remote_addr(), - "Attempting remote connection to {remote}" - ); - self.start_outbound_connection(remote.clone(), tracker.tx, false).await; - let current_span = tracing::Span::current(); - let checking_hops_span = tracing::info_span!(parent: current_span, "checking_hops"); - self.ongoing_outbound_connections.push( - 
check_remaining_hops(tracker).instrument(checking_hops_span).boxed() - ); - continue; - } - Some(Ok(InternalEvent::DropInboundConnection(addr))) => { - self.connecting.remove(&addr); - self.outbound_messages.remove(&addr); - continue; - } - Some(Err((peer_id, error))) => { - tracing::debug!(from=%peer_id.addr, "Outbound connection failed: {error}"); - self.connecting.remove(&peer_id.addr); - self.outbound_messages.remove(&peer_id.addr); - self.connection_manager.prune_alive_connection(&peer_id); - Ok(Event::OutboundConnectionFailed { peer_id, error }) - } - Some(Ok(other)) => { - tracing::error!("Unexpected event: {other:?}"); - continue; - } - None => Err(HandshakeError::ChannelClosed), - }; - break r; - } - // Handle unconfirmed inbound connections (only applies in gateways) - unconfirmed_inbound_conn = self.unconfirmed_inbound_connections.next(), if !self.unconfirmed_inbound_connections.is_empty() => { - tracing::debug!("Processing unconfirmed inbound connection"); - let Some(res) = unconfirmed_inbound_conn else { - return Err(HandshakeError::ChannelClosed); - }; - let (event, outbound_sender) = res?; - tracing::debug!("Unconfirmed connection event: {:?}", event); - match event { - InternalEvent::InboundGwJoinRequest(mut req) => { - let location = if let Some((_, other)) = self.this_location.zip(req.location) { - other - } else { - Location::from_address(&req.conn.remote_addr()) - }; - let should_accept = self.connection_manager.should_accept(location, &req.joiner); - // Check if this is a valid acceptance scenario - // Non-gateways with 0 connections should not accept (they need existing connections to forward through) - let can_accept = should_accept && - (self.is_gateway || self.connection_manager.num_connections() > 0); - - if can_accept { - let accepted_msg = NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Response { - id: req.id, - sender: self.connection_manager.own_location(), - target: PeerKeyLocation { - peer: req.joiner.clone(), - location: Some(location), - }, - msg: ConnectResponse::AcceptedBy { - accepted: true, - acceptor: self.connection_manager.own_location(), - joiner: req.joiner.clone(), - }, - })); - - tracing::debug!(at=?req.conn.my_address(), from=%req.conn.remote_addr(), "Accepting connection"); - - if let Err(e) = req.conn.send(accepted_msg).await { - tracing::error!(%e, "Failed to send accepted message from gw, pruning reserved connection"); - self.connection_manager.prune_in_transit_connection(&req.joiner); - return Err(e.into()); - } - - let InboundGwJoinRequest { - conn, - id, - hops_to_live, - max_hops_to_live, - skip_connections, - skip_forwards, - joiner, - .. 
- } = req; - - // Forward the connection or accept it directly - let forward_result = { - // TODO: refactor this so it happens in the background out of the main handler loop - let mut nw_bridge = ForwardPeerMessage { - msg: parking_lot::Mutex::new(None), - }; - - let my_peer_id = self.connection_manager.own_location(); - - let joiner_pk_loc = PeerKeyLocation { - peer: joiner.clone(), - location: Some(location), - }; - - let mut skip_connections = skip_connections.clone(); - let mut skip_forwards = skip_forwards.clone(); - skip_connections.insert(my_peer_id.peer.clone()); - skip_forwards.insert(my_peer_id.peer.clone()); - - let forward_info = ForwardParams { - left_htl: hops_to_live, - max_htl: max_hops_to_live, - accepted: true, - skip_connections, - skip_forwards, - req_peer: my_peer_id.clone(), - joiner: joiner_pk_loc.clone(), - is_gateway: self.is_gateway, - }; - - let f = forward_conn( - id, - &self.connection_manager, - self.router.clone(), - &mut nw_bridge, - forward_info - ); - - match f.await { - Err(err) => { - tracing::error!(%err, "Error forwarding connection"); - continue; - } - Ok(Some(conn_state)) => { - let ConnectState::AwaitingConnectivity(info) = conn_state else { - unreachable!("forward_conn should return AwaitingConnectivity if successful") - }; - - // Check if we have a forward message (forwarding) or not (direct acceptance) - if let Some((forward_target, msg)) = nw_bridge.msg.into_inner() { - (Some(ForwardResult::Forward(forward_target.clone(), msg, info)), Some(forward_target)) - } else if info.is_bootstrap_acceptance { - // Gateway bootstrap case: connection should be registered immediately - // This bypasses the normal CheckConnectivity flow. See forward_conn() - // bootstrap logic in connect.rs for full explanation. - (Some(ForwardResult::BootstrapAccepted(info)), None) - } else { - // Normal direct acceptance - will wait for CheckConnectivity - (Some(ForwardResult::DirectlyAccepted(info)), None) - } - } - Ok(None) => (None, None), - } - }; - - tracing::info!(%id, %joiner, "Creating InboundConnection event"); - match forward_result { - (Some(ForwardResult::Forward(forward_target, msg, info)), _) => { - return Ok(Event::InboundConnection { - id, - conn, - joiner, - op: Some(Box::new(ConnectOp::new(id, Some(ConnectState::AwaitingConnectivity(info)), None, None))), - forward_info: Some(Box::new(ForwardInfo { target: forward_target, msg })), - is_bootstrap: false, - }); - } - (Some(ForwardResult::BootstrapAccepted(info)), _) => { - return Ok(Event::InboundConnection { - id, - conn, - joiner, - op: Some(Box::new(ConnectOp::new(id, Some(ConnectState::AwaitingConnectivity(info)), None, None))), - forward_info: None, - is_bootstrap: true, - }); - } - (Some(ForwardResult::DirectlyAccepted(info)), _) => { - return Ok(Event::InboundConnection { - id, - conn, - joiner, - op: Some(Box::new(ConnectOp::new(id, Some(ConnectState::AwaitingConnectivity(info)), None, None))), - forward_info: None, - is_bootstrap: false, - }); - } - (Some(ForwardResult::Rejected), _) | (None, _) => { - return Ok(Event::InboundConnection { - id, - conn, - joiner, - op: None, - forward_info: None, - is_bootstrap: false, - }); - } - } - - } else { - // If should_accept was true but we can't actually accept (non-gateway with 0 connections), - // we need to clean up the reserved connection - if should_accept && !can_accept { - self.connection_manager.prune_in_transit_connection(&req.joiner); - tracing::debug!( - "Non-gateway with 0 connections cannot accept connection from {:?}", - req.joiner - ); - } - - 
let InboundGwJoinRequest { - mut conn, - id, - hops_to_live, - max_hops_to_live, - skip_connections, - skip_forwards, - joiner, - .. - } = req; - let remote = conn.remote_addr(); - tracing::debug!(at=?conn.my_address(), from=%remote, "Transient connection"); - let mut tx = TransientConnection { - tx: id, - joiner: joiner.clone(), - max_hops_to_live, - hops_to_live, - skip_connections, - skip_forwards, - }; - match self.forward_transient_connection(&mut conn, &mut tx).await { - Ok(ForwardResult::Forward(forward_target, msg, info)) => { - self.unconfirmed_inbound_connections.push( - gw_transient_peer_conn( - conn, - outbound_sender, - tx, - info, - ).boxed() - ); - return Ok(Event::TransientForwardTransaction { - target: remote, - tx: id, - forward_to: forward_target, - msg: Box::new(msg), - }); - } - Ok(ForwardResult::BootstrapAccepted(info)) => { - // Gateway bootstrap: First connection should be registered immediately. - // This bypasses the normal transient connection flow. - // See forward_conn() in connect.rs for full explanation. - return Ok(Event::InboundConnection { - id, - conn, - joiner, - op: Some(Box::new(ConnectOp::new(id, Some(ConnectState::AwaitingConnectivity(info)), None, None))), - forward_info: None, - is_bootstrap: true, - }); - } - Ok(ForwardResult::DirectlyAccepted(_info)) => { - // Connection was accepted directly (not forwarded) - // For now, treat this as a rejection since we shouldn't hit this case in transient connections - // Clean up the reserved connection slot - self.connection_manager.prune_in_transit_connection(&joiner); - self.outbound_messages.remove(&remote); - self.connecting.remove(&remote); - return Ok(Event::InboundConnectionRejected { peer_id: joiner }); - } - Ok(ForwardResult::Rejected) => { - // Clean up the reserved connection slot - self.connection_manager.prune_in_transit_connection(&joiner); - self.outbound_messages.remove(&remote); - self.connecting.remove(&remote); - return Ok(Event::InboundConnectionRejected { peer_id: joiner }); - } - Err(e) => { - tracing::error!(from=%remote, "Error forwarding transient connection: {e}"); - return Err(e); - } - } - } - } - InternalEvent::DropInboundConnection(addr) => { - self.outbound_messages.remove(&addr); - self.connecting.remove(&addr); - continue; - } - other => { - tracing::error!("Unexpected event: {other:?}"); - continue; - } - } - } - // Process pending messages for unconfirmed connections - pending_msg = self.pending_msg_rx.recv() => { - let Some((addr, msg)) = pending_msg else { - return Err(HandshakeError::ChannelClosed); - }; - if let Some(event) = self.outbound(addr, msg).await { - break Ok(event); - } - } - // Handle requests to establish new connections - establish_connection = self.establish_connection_rx.recv() => { - match establish_connection { - Some(ExternConnection::Establish { peer, tx, is_gw }) => { - self.start_outbound_connection(peer, tx, is_gw).await; - } - Some(ExternConnection::Dropped { peer }) => { - self.connected.remove(&peer.addr); - self.outbound_messages.remove(&peer.addr); - self.connecting.remove(&peer.addr); - } - Some(ExternConnection::DropConnectionByAddr(addr)) => { - self.connected.remove(&addr); - self.outbound_messages.remove(&addr); - self.connecting.remove(&addr); - } - None => return Err(HandshakeError::ChannelClosed), - } - } - } - } - } - - async fn forward_transient_connection( - &mut self, - conn: &mut PeerConnection, - transaction: &mut TransientConnection, - ) -> Result { - let mut nw_bridge = ForwardPeerMessage { - msg: 
parking_lot::Mutex::new(None), - }; - - let joiner_loc = if let Some(own_loc) = self.this_location { - own_loc - } else { - Location::from_address(&conn.remote_addr()) - }; - let joiner_pk_loc = PeerKeyLocation { - peer: transaction.joiner.clone(), - location: Some(joiner_loc), - }; - let my_peer_id = self.connection_manager.own_location(); - transaction - .skip_connections - .insert(transaction.joiner.clone()); - transaction.skip_forwards.insert(transaction.joiner.clone()); - transaction.skip_connections.insert(my_peer_id.peer.clone()); - transaction.skip_forwards.insert(my_peer_id.peer.clone()); - - let forward_info = ForwardParams { - left_htl: transaction.hops_to_live, - max_htl: transaction.max_hops_to_live, - accepted: true, - skip_connections: transaction.skip_connections.clone(), - skip_forwards: transaction.skip_forwards.clone(), - req_peer: my_peer_id.clone(), - joiner: joiner_pk_loc.clone(), - is_gateway: self.is_gateway, - }; - - match forward_conn( - transaction.tx, - &self.connection_manager, - self.router.clone(), - &mut nw_bridge, - forward_info, - ) - .await - { - Ok(Some(conn_state)) => { - let ConnectState::AwaitingConnectivity(info) = conn_state else { - unreachable!("forward_conn should return AwaitingConnectivity if successful") - }; - - // Check if we have a forward message (forwarding) or not (direct acceptance) - if let Some((forward_target, msg)) = nw_bridge.msg.into_inner() { - Ok(ForwardResult::Forward(forward_target, msg, info)) - } else if info.is_bootstrap_acceptance { - // Gateway bootstrap case: connection should be registered immediately - // This bypasses the normal CheckConnectivity flow. See forward_conn() - // bootstrap logic in connect.rs for full explanation. - Ok(ForwardResult::BootstrapAccepted(info)) - } else { - // Normal direct acceptance - will wait for CheckConnectivity - Ok(ForwardResult::DirectlyAccepted(info)) - } - } - Ok(None) => { - tracing::debug!(at=?conn.my_address(), from=%conn.remote_addr(), "Rejecting connection, no peers found to forward"); - // No peer to forward to, reject the connection - let reject_msg = NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Response { - id: transaction.tx, - sender: my_peer_id.clone(), - target: joiner_pk_loc, - msg: ConnectResponse::AcceptedBy { - accepted: false, - acceptor: my_peer_id, - joiner: transaction.joiner.clone(), - }, - })); - conn.send(reject_msg).await?; - tracing::debug!(at=?conn.my_address(), from=%conn.remote_addr(), "Connection rejected"); - Ok(ForwardResult::Rejected) - } - Err(_) => Err(HandshakeError::ConnectionClosed(conn.remote_addr())), - } - } - /// Tracks a new inbound connection and sets up message handling for it. fn track_inbound_connection(&mut self, conn: PeerConnection) { let (outbound_msg_sender, outbound_msg_recv) = mpsc::channel(100); @@ -859,37 +378,733 @@ impl HandshakeHandler { .boxed(); self.ongoing_outbound_connections.push(f); } +} - /// Waits for confirmation from a gateway after establishing a connection. 
- async fn wait_for_gw_confirmation( - &mut self, - gw_peer_id: PeerId, - conn: PeerConnection, - max_hops_to_live: usize, - ) -> Result<()> { - let tx = *self - .connecting - .get(&gw_peer_id.addr) - .ok_or_else(|| HandshakeError::ConnectionClosed(conn.remote_addr()))?; - let this_peer = self.connection_manager.own_location().peer; - tracing::debug!(at=?conn.my_address(), %this_peer.addr, from=%conn.remote_addr(), remote_addr = %gw_peer_id, "Waiting for confirmation from gw"); - self.ongoing_outbound_connections.push( - wait_for_gw_confirmation( - (this_peer, self.this_location), - AcceptedTracker { - gw_peer: gw_peer_id.into(), - gw_conn: conn, - gw_accepted: false, - gw_accepted_processed: false, - remaining_checks: max_hops_to_live, - accepted: 0, - total_checks: max_hops_to_live, - tx, - }, - ) - .boxed(), - ); - Ok(()) +/// Stream wrapper that takes ownership of HandshakeHandler and implements Stream properly. +/// This converts the event loop logic from wait_for_events into a proper Stream implementation. +pub(super) struct HandshakeEventStream { + handler: HandshakeHandler, +} + +impl HandshakeEventStream { + pub fn new(handler: HandshakeHandler) -> Self { + Self { handler } + } +} + +impl futures::stream::Stream for HandshakeEventStream { + type Item = Result; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + use std::task::Poll; + + let handler = &mut self.handler; + + // Main event loop - mirrors the original `loop { tokio::select! {...} }` structure + // We loop internally to handle "continue" cases without returning to the executor + loop { + tracing::trace!( + "HandshakeEventStream::poll_next iteration - unconfirmed: {}, ongoing_outbound: {}", + handler.unconfirmed_inbound_connections.len(), + handler.ongoing_outbound_connections.len() + ); + + // Priority 1: Handle new inbound connections + // Poll the future and extract the result, then drop it before using handler again + let inbound_result = { + let inbound_fut = handler.inbound_conn_handler.next_connection(); + tokio::pin!(inbound_fut); + inbound_fut.poll(cx) + }; // inbound_fut dropped here + + match inbound_result { + Poll::Ready(Some(conn)) => { + tracing::debug!(from=%conn.remote_addr(), "New inbound connection"); + handler.track_inbound_connection(conn); + // This was a `continue` in the loop - loop again to re-poll all priorities + continue; + } + Poll::Ready(None) => { + return Poll::Ready(Some(Err(HandshakeError::ChannelClosed))); + } + Poll::Pending => {} + } + + // Priority 2: Process outbound connection attempts + if !handler.ongoing_outbound_connections.is_empty() { + match std::pin::Pin::new(&mut handler.ongoing_outbound_connections).poll_next(cx) { + Poll::Ready(Some(outbound_result)) => { + // Handle the result - may return event or continue + let result = handle_outbound_result(handler, outbound_result, cx); + if let Some(event) = result { + return Poll::Ready(Some(event)); + } else { + // Was a continue case - loop again to re-poll all priorities + continue; + } + } + Poll::Ready(None) => { + // FuturesUnordered is now empty - this is normal, just continue to next channel + } + Poll::Pending => {} + } + } + + // Priority 3: Handle unconfirmed inbound connections (for gateways) + if !handler.unconfirmed_inbound_connections.is_empty() { + match std::pin::Pin::new(&mut handler.unconfirmed_inbound_connections).poll_next(cx) + { + Poll::Ready(Some(res)) => { + tracing::debug!("Processing unconfirmed inbound connection"); + let (event, 
outbound_sender) = match res { + Ok(v) => v, + Err(e) => return Poll::Ready(Some(Err(e))), + }; + tracing::debug!("Unconfirmed connection event: {:?}", event); + let result = + handle_unconfirmed_inbound(handler, event, outbound_sender, cx); + if let Some(event) = result { + return Poll::Ready(Some(event)); + } else { + // Was a continue case - loop again to re-poll all priorities + continue; + } + } + Poll::Ready(None) => { + // FuturesUnordered is now empty - this is normal, just continue to next channel + } + Poll::Pending => {} + } + } + + // Priority 4: Handle outbound message requests + match handler.pending_msg_rx.poll_recv(cx) { + Poll::Ready(Some((addr, msg))) => { + // Call handler.outbound() - this returns Option + // Scope to drop the future borrow immediately + let result = { + let outbound_fut = handler.outbound(addr, msg); + tokio::pin!(outbound_fut); + outbound_fut.poll(cx) + }; + match result { + Poll::Ready(Some(event)) => { + return Poll::Ready(Some(Ok(event))); + } + Poll::Ready(None) => { + // outbound() returned None - continue to re-poll all priorities + continue; + } + Poll::Pending => { + // The outbound future is pending - continue to next priority + } + } + } + Poll::Ready(None) => { + return Poll::Ready(Some(Err(HandshakeError::ChannelClosed))); + } + Poll::Pending => {} + } + + // Priority 5: Handle connection establishment requests + match handler.establish_connection_rx.poll_recv(cx) { + Poll::Ready(Some(ExternConnection::Establish { peer, tx, is_gw })) => { + // Start outbound connection - call the async method + // Scope to drop the future borrow immediately + let _ = { + let start_fut = handler.start_outbound_connection(peer, tx, is_gw); + tokio::pin!(start_fut); + start_fut.poll(cx) + }; + // Poll it immediately - it will push futures to ongoing_outbound_connections + // Then loop again to re-poll all priorities (ongoing_outbound_connections might have work) + continue; + } + Poll::Ready(Some(ExternConnection::Dropped { peer })) => { + handler.connected.remove(&peer.addr); + handler.outbound_messages.remove(&peer.addr); + handler.connecting.remove(&peer.addr); + // Continue to re-poll all priorities + continue; + } + Poll::Ready(Some(ExternConnection::DropConnectionByAddr(addr))) => { + handler.connected.remove(&addr); + handler.outbound_messages.remove(&addr); + handler.connecting.remove(&addr); + // Continue to re-poll all priorities + continue; + } + Poll::Ready(None) => { + return Poll::Ready(Some(Err(HandshakeError::ChannelClosed))); + } + Poll::Pending => {} + } + + // All channels are pending - return Pending and wait to be woken + return Poll::Pending; + } // end of loop + } +} + +// Helper to handle outbound connection results +// Returns Some(event) if should return an event, None if should continue +fn handle_outbound_result( + handler: &mut HandshakeHandler, + result: OutboundConnResult, + cx: &mut std::task::Context<'_>, +) -> Option> { + match result { + Ok(InternalEvent::OutboundConnEstablished(peer_id, connection)) => { + tracing::info!(at=?connection.my_address(), from=%connection.remote_addr(), "Outbound connection successful"); + Some(Ok(Event::OutboundConnectionSuccessful { + peer_id, + connection, + })) + } + Ok(InternalEvent::OutboundGwConnEstablished(id, connection)) => { + tracing::info!(at=?connection.my_address(), from=%connection.remote_addr(), "Outbound gateway connection successful"); + if let Some(addr) = connection.my_address() { + tracing::debug!(%addr, "Attempting setting own peer key"); + 
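+                // The gateway connection reveals our externally observed address; it is used to seed
+                // the local peer key and, when no location was configured, to derive one from it below.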
handler.connection_manager.try_set_peer_key(addr); + + if let Some(ref peer_ready) = handler.peer_ready { + peer_ready.store(true, std::sync::atomic::Ordering::SeqCst); + tracing::info!("Peer initialization complete: peer_ready set to true, client operations now enabled"); + } + + if handler.this_location.is_none() { + handler + .connection_manager + .update_location(Some(Location::from_address(&addr))); + } + } + tracing::debug!(at=?connection.my_address(), from=%connection.remote_addr(), "Outbound connection to gw successful"); + + // Call wait_for_gw_confirmation - it pushes a future to ongoing_outbound_connections + let tx = match handler.connecting.get(&id.addr) { + Some(t) => *t, + None => { + tracing::error!("Transaction not found for gateway connection"); + return Some(Err(HandshakeError::ConnectionClosed( + connection.remote_addr(), + ))); + } + }; + let this_peer = handler.connection_manager.own_location().peer; + tracing::debug!(at=?connection.my_address(), %this_peer.addr, from=%connection.remote_addr(), remote_addr = %id, "Waiting for confirmation from gw"); + handler.ongoing_outbound_connections.push( + wait_for_gw_confirmation( + (this_peer, handler.this_location), + AcceptedTracker { + gw_peer: id.into(), + gw_conn: connection, + gw_accepted: false, + gw_accepted_processed: false, + remaining_checks: Ring::DEFAULT_MAX_HOPS_TO_LIVE, + accepted: 0, + total_checks: Ring::DEFAULT_MAX_HOPS_TO_LIVE, + tx, + }, + ) + .boxed(), + ); + None // Continue + } + Ok(InternalEvent::FinishedOutboundConnProcess(tracker)) => { + handler.connecting.remove(&tracker.gw_peer.peer.addr); + tracing::debug!(at=?tracker.gw_conn.my_address(), gw=%tracker.gw_conn.remote_addr(), "Done checking, connection not accepted by gw, dropping connection"); + Some(Ok(Event::OutboundGatewayConnectionRejected { + peer_id: tracker.gw_peer.peer, + })) + } + Ok(InternalEvent::OutboundGwConnConfirmed(tracker)) => { + tracing::debug!(at=?tracker.gw_conn.my_address(), from=%tracker.gw_conn.remote_addr(), "Outbound connection to gw confirmed"); + handler.connected.insert(tracker.gw_conn.remote_addr()); + handler.connecting.remove(&tracker.gw_conn.remote_addr()); + Some(Ok(Event::OutboundGatewayConnectionSuccessful { + peer_id: tracker.gw_peer.peer, + connection: tracker.gw_conn, + remaining_checks: tracker.remaining_checks, + })) + } + Ok(InternalEvent::NextCheck(tracker)) => { + handler + .ongoing_outbound_connections + .push(check_remaining_hops(tracker).boxed()); + None // Continue + } + Ok(InternalEvent::RemoteConnectionAttempt { remote, tracker }) => { + debug_assert!(!tracker.gw_accepted); + tracing::debug!( + at=?tracker.gw_conn.my_address(), + gw=%tracker.gw_conn.remote_addr(), + "Attempting remote connection to {remote}" + ); + + // Start outbound connection - poll it immediately to start the work + let _result = { + let start_fut = + handler.start_outbound_connection(remote.clone(), tracker.tx, false); + tokio::pin!(start_fut); + start_fut.poll(cx) + }; + + // Whether it completes or pends, push check_remaining_hops + let current_span = tracing::Span::current(); + let checking_hops_span = tracing::info_span!(parent: current_span, "checking_hops"); + handler.ongoing_outbound_connections.push( + check_remaining_hops(tracker) + .instrument(checking_hops_span) + .boxed(), + ); + None // Continue + } + Ok(InternalEvent::DropInboundConnection(addr)) => { + handler.connecting.remove(&addr); + handler.outbound_messages.remove(&addr); + None // Continue + } + Err((peer_id, error)) => { + 
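+            // Failed outbound attempt: drop all local bookkeeping for this peer before reporting the failure.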
tracing::debug!(from=%peer_id.addr, "Outbound connection failed: {error}"); + handler.connecting.remove(&peer_id.addr); + handler.outbound_messages.remove(&peer_id.addr); + handler.connection_manager.prune_alive_connection(&peer_id); + Some(Ok(Event::OutboundConnectionFailed { peer_id, error })) + } + Ok(other) => { + tracing::error!("Unexpected event: {other:?}"); + None // Continue + } + } +} + +// Helper to handle unconfirmed inbound events +// Returns Some(event) if should return, None if should continue +fn handle_unconfirmed_inbound( + handler: &mut HandshakeHandler, + event: InternalEvent, + outbound_sender: PeerOutboundMessage, + _cx: &mut std::task::Context<'_>, +) -> Option> { + match event { + InternalEvent::InboundGwJoinRequest(req) => { + // This requires async work - spawn it as a future + let conn_manager = handler.connection_manager.clone(); + let router = handler.router.clone(); + let this_location = handler.this_location; + let is_gateway = handler.is_gateway; + + // Spawn the async handling + let fut = handle_inbound_gw_join_request( + req, + conn_manager, + router, + this_location, + is_gateway, + outbound_sender, + ); + + handler.unconfirmed_inbound_connections.push(fut.boxed()); + None + } + InternalEvent::InboundConnectionAccepted { + id, + conn, + joiner, + op, + forward_info, + is_bootstrap, + } => { + tracing::debug!(%joiner, "Inbound connection accepted"); + // The outbound sender was already stored in outbound_messages by track_inbound_connection + // We just need to return the event + Some(Ok(Event::InboundConnection { + id, + conn, + joiner, + op, + forward_info, + is_bootstrap, + })) + } + InternalEvent::InboundConnectionRejected { peer_id, remote } => { + tracing::debug!(%peer_id, %remote, "Inbound connection rejected"); + handler.outbound_messages.remove(&remote); + handler.connecting.remove(&remote); + Some(Ok(Event::InboundConnectionRejected { peer_id })) + } + InternalEvent::TransientForward { + conn, + tx, + info, + target, + forward_to, + msg, + } => { + tracing::debug!(%target, %forward_to, "Transient forward"); + // Save transaction ID before moving tx + let transaction_id = tx.tx; + // Push gw_transient_peer_conn future to monitor this connection + handler + .unconfirmed_inbound_connections + .push(gw_transient_peer_conn(conn, outbound_sender, tx, info).boxed()); + Some(Ok(Event::TransientForwardTransaction { + target, + tx: transaction_id, + forward_to, + msg, + })) + } + InternalEvent::DropInboundConnection(addr) => { + tracing::debug!(%addr, "Dropping inbound connection"); + handler.outbound_messages.remove(&addr); + None + } + _ => { + tracing::warn!("Unhandled unconfirmed inbound event: {:?}", event); + None + } + } +} + +// Async function to handle InboundGwJoinRequest +async fn handle_inbound_gw_join_request( + mut req: InboundGwJoinRequest, + conn_manager: ConnectionManager, + router: Arc>, + this_location: Option, + is_gateway: bool, + outbound_sender: PeerOutboundMessage, +) -> Result<(InternalEvent, PeerOutboundMessage), HandshakeError> { + let location = if let Some((_, other)) = this_location.zip(req.location) { + other + } else { + Location::from_address(&req.conn.remote_addr()) + }; + + let should_accept = conn_manager.should_accept(location, &req.joiner); + let can_accept = should_accept && (is_gateway || conn_manager.num_connections() > 0); + + if can_accept { + // Accepted connection path: Send acceptance message, then forward + let accepted_msg = NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Response { + id: req.id, + sender: 
conn_manager.own_location(), + target: PeerKeyLocation { + peer: req.joiner.clone(), + location: Some(location), + }, + msg: ConnectResponse::AcceptedBy { + accepted: true, + acceptor: conn_manager.own_location(), + joiner: req.joiner.clone(), + }, + })); + + tracing::debug!(at=?req.conn.my_address(), from=%req.conn.remote_addr(), "Accepting connection"); + + if let Err(e) = req.conn.send(accepted_msg).await { + tracing::error!(%e, "Failed to send accepted message from gw, pruning reserved connection"); + conn_manager.prune_in_transit_connection(&req.joiner); + return Err(e.into()); + } + + let InboundGwJoinRequest { + conn, + id, + hops_to_live, + max_hops_to_live, + skip_connections, + skip_forwards, + joiner, + .. + } = req; + + // Forward the connection + let mut nw_bridge = ForwardPeerMessage { + msg: parking_lot::Mutex::new(None), + }; + + let my_peer_id = conn_manager.own_location(); + let joiner_pk_loc = PeerKeyLocation { + peer: joiner.clone(), + location: Some(location), + }; + + let mut skip_connections = skip_connections.clone(); + let mut skip_forwards = skip_forwards.clone(); + skip_connections.insert(my_peer_id.peer.clone()); + skip_forwards.insert(my_peer_id.peer.clone()); + + let forward_info = ForwardParams { + left_htl: hops_to_live, + max_htl: max_hops_to_live, + accepted: true, + skip_connections, + skip_forwards, + req_peer: my_peer_id.clone(), + joiner: joiner_pk_loc.clone(), + is_gateway, + }; + + match forward_conn( + id, + &conn_manager, + router.clone(), + &mut nw_bridge, + forward_info, + ) + .await + { + Err(err) => { + tracing::error!(%err, "Error forwarding connection"); + // Continue by returning DropInboundConnection + Ok(( + InternalEvent::DropInboundConnection(conn.remote_addr()), + outbound_sender, + )) + } + Ok(Some(conn_state)) => { + let ConnectState::AwaitingConnectivity(info) = conn_state else { + unreachable!("forward_conn should return AwaitingConnectivity if successful") + }; + + tracing::info!(%id, %joiner, "Creating InboundConnection event"); + + // Check if we have a forward message (forwarding) or not (direct acceptance) + let (op, forward_info_opt, is_bootstrap) = + if let Some((forward_target, msg)) = nw_bridge.msg.into_inner() { + ( + Some(Box::new(ConnectOp::new( + id, + Some(ConnectState::AwaitingConnectivity(info)), + None, + None, + ))), + Some(Box::new(ForwardInfo { + target: forward_target, + msg, + })), + false, + ) + } else if info.is_bootstrap_acceptance { + // Gateway bootstrap case: connection should be registered immediately + ( + Some(Box::new(ConnectOp::new( + id, + Some(ConnectState::AwaitingConnectivity(info)), + None, + None, + ))), + None, + true, + ) + } else { + // Normal direct acceptance - will wait for CheckConnectivity + ( + Some(Box::new(ConnectOp::new( + id, + Some(ConnectState::AwaitingConnectivity(info)), + None, + None, + ))), + None, + false, + ) + }; + + Ok(( + InternalEvent::InboundConnectionAccepted { + id, + conn, + joiner, + op, + forward_info: forward_info_opt, + is_bootstrap, + }, + outbound_sender, + )) + } + Ok(None) => { + // No forwarding target found - return event with op: None to signal rejection + // This matches original behavior where forward_result (None, _) returns Event with op: None + Ok(( + InternalEvent::InboundConnectionAccepted { + id, + conn, + joiner, + op: None, // Signals rejection/no forwarding possible + forward_info: None, + is_bootstrap: false, + }, + outbound_sender, + )) + } + } + } else { + // Transient connection path: Try to forward without accepting + // If should_accept 
was true but we can't actually accept (non-gateway with 0 connections), + // we need to clean up the reserved connection + if should_accept && !can_accept { + conn_manager.prune_in_transit_connection(&req.joiner); + tracing::debug!( + "Non-gateway with 0 connections cannot accept connection from {:?}", + req.joiner + ); + } + + let InboundGwJoinRequest { + mut conn, + id, + hops_to_live, + max_hops_to_live, + skip_connections, + skip_forwards, + joiner, + .. + } = req; + + let remote = conn.remote_addr(); + tracing::debug!(at=?conn.my_address(), from=%remote, "Transient connection"); + + // Try to forward the connection without accepting it + let joiner_loc = this_location.unwrap_or_else(|| Location::from_address(&remote)); + let joiner_pk_loc = PeerKeyLocation { + peer: joiner.clone(), + location: Some(joiner_loc), + }; + let my_peer_id = conn_manager.own_location(); + + let mut skip_connections_updated = skip_connections.clone(); + let mut skip_forwards_updated = skip_forwards.clone(); + skip_connections_updated.insert(joiner.clone()); + skip_forwards_updated.insert(joiner.clone()); + skip_connections_updated.insert(my_peer_id.peer.clone()); + skip_forwards_updated.insert(my_peer_id.peer.clone()); + + let forward_info = ForwardParams { + left_htl: hops_to_live, + max_htl: max_hops_to_live, + accepted: true, + skip_connections: skip_connections_updated, + skip_forwards: skip_forwards_updated, + req_peer: my_peer_id.clone(), + joiner: joiner_pk_loc.clone(), + is_gateway, + }; + + let mut nw_bridge = ForwardPeerMessage { + msg: parking_lot::Mutex::new(None), + }; + + match forward_conn( + id, + &conn_manager, + router.clone(), + &mut nw_bridge, + forward_info, + ) + .await + { + Ok(Some(conn_state)) => { + let ConnectState::AwaitingConnectivity(info) = conn_state else { + unreachable!("forward_conn should return AwaitingConnectivity if successful") + }; + + // Check the forwarding result + if let Some((forward_target, msg)) = nw_bridge.msg.into_inner() { + // Successfully forwarding to another peer + // Create a TransientConnection to track this + let tx = TransientConnection { + tx: id, + joiner: joiner.clone(), + }; + + // Push gw_transient_peer_conn future to monitor this connection + Ok(( + InternalEvent::TransientForward { + conn, + tx, + info, + target: remote, + forward_to: forward_target, + msg: Box::new(msg), + }, + outbound_sender, + )) + } else if info.is_bootstrap_acceptance { + // Bootstrap acceptance - accept it directly even though we didn't send acceptance yet + Ok(( + InternalEvent::InboundConnectionAccepted { + id, + conn, + joiner, + op: Some(Box::new(ConnectOp::new( + id, + Some(ConnectState::AwaitingConnectivity(info)), + None, + None, + ))), + forward_info: None, + is_bootstrap: true, + }, + outbound_sender, + )) + } else { + // Direct acceptance without forwarding - shouldn't happen for transient + // Clean up and reject + conn_manager.prune_in_transit_connection(&joiner); + Ok(( + InternalEvent::InboundConnectionRejected { + peer_id: joiner, + remote, + }, + outbound_sender, + )) + } + } + Ok(None) => { + // No peer to forward to - send rejection message + tracing::debug!(at=?conn.my_address(), from=%conn.remote_addr(), "Rejecting connection, no peers found to forward"); + let reject_msg = NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Response { + id, + sender: my_peer_id.clone(), + target: joiner_pk_loc, + msg: ConnectResponse::AcceptedBy { + accepted: false, + acceptor: my_peer_id, + joiner: joiner.clone(), + }, + })); + + if let Err(e) = 
conn.send(reject_msg).await { + tracing::error!(%e, "Failed to send rejection message"); + return Err(e.into()); + } + + // Clean up and reject + conn_manager.prune_in_transit_connection(&joiner); + Ok(( + InternalEvent::InboundConnectionRejected { + peer_id: joiner, + remote, + }, + outbound_sender, + )) + } + Err(e) => { + tracing::error!(from=%remote, "Error forwarding transient connection: {e}"); + // Drop the connection and clean up + conn_manager.prune_in_transit_connection(&joiner); + Ok(( + InternalEvent::DropInboundConnection(remote), + outbound_sender, + )) + } + } } } @@ -949,9 +1164,31 @@ enum InternalEvent { }, NextCheck(AcceptedTracker), FinishedOutboundConnProcess(AcceptedTracker), + // New variants for forwarding results + InboundConnectionAccepted { + id: Transaction, + conn: PeerConnection, + joiner: PeerId, + op: Option>, + forward_info: Option>, + is_bootstrap: bool, + }, + InboundConnectionRejected { + peer_id: PeerId, + remote: SocketAddr, + }, + TransientForward { + conn: PeerConnection, + tx: TransientConnection, + info: ConnectivityInfo, + target: SocketAddr, + forward_to: PeerId, + msg: Box, + }, } #[repr(transparent)] +#[derive(Debug)] struct PeerOutboundMessage(mpsc::Receiver); #[derive(Debug)] @@ -1285,13 +1522,19 @@ async fn gw_transient_peer_conn( } } +/// Tracks a transient connection that is being forwarded through this gateway. +/// This struct is only used by `gw_transient_peer_conn` to identify and validate +/// drop connection messages from the joiner. +/// +/// Note: In the original implementation, this struct also contained `max_hops_to_live`, +/// `hops_to_live`, `skip_connections`, and `skip_forwards` fields that were used by +/// the `forward_transient_connection` method. In the stream-based refactoring, these +/// values are used directly from the `InboundGwJoinRequest` when calling `forward_conn`, +/// so they don't need to be stored in this struct. +#[derive(Debug)] struct TransientConnection { tx: Transaction, joiner: PeerId, - max_hops_to_live: usize, - hops_to_live: usize, - skip_connections: HashSet, - skip_forwards: HashSet, } impl TransientConnection { @@ -1599,10 +1842,21 @@ mod tests { open_connection } + // ============================================================================ + // Stream-based tests for HandshakeEventStream + // ============================================================================ + + /// Helper to get the next event from a HandshakeEventStream + async fn next_stream_event(stream: &mut HandshakeEventStream) -> Result { + use futures::StreamExt; + stream.next().await.ok_or(HandshakeError::ChannelClosed)? + } + #[tokio::test] - async fn test_gateway_inbound_conn_success() -> anyhow::Result<()> { + async fn test_stream_gateway_inbound_conn_success() -> anyhow::Result<()> { let addr: SocketAddr = ([127, 0, 0, 1], 10000).into(); - let (mut handler, mut test) = config_handler(addr, None, true); + let (handler, mut test) = config_handler(addr, None, true); + let mut stream = HandshakeEventStream::new(handler); let remote_addr = ([127, 0, 0, 1], 10001).into(); let test_controller = async { @@ -1616,7 +1870,8 @@ mod tests { let gw_inbound = async { let event = - tokio::time::timeout(Duration::from_secs(15), handler.wait_for_events()).await??; + tokio::time::timeout(Duration::from_secs(15), next_stream_event(&mut stream)) + .await??; match event { Event::InboundConnection { conn, .. 
} => { assert_eq!(conn.remote_addr(), remote_addr); @@ -1630,71 +1885,70 @@ mod tests { } #[tokio::test] - async fn test_gateway_inbound_conn_rejected() -> anyhow::Result<()> { + async fn test_stream_gateway_inbound_conn_rejected() -> anyhow::Result<()> { let addr: SocketAddr = ([127, 0, 0, 1], 10000).into(); - let existing_remote_addr = ([127, 0, 0, 1], 10001).into(); - let remote_peer_loc = PeerKeyLocation { - peer: PeerId::new( - existing_remote_addr, - TransportKeypair::new().public().clone(), - ), - location: Some(Location::from_address(&existing_remote_addr)), - }; - let existing_conn = - Connection::new(remote_peer_loc.peer, remote_peer_loc.location.unwrap()); - - let (mut handler, mut test) = config_handler(addr, Some(vec![existing_conn]), true); - - // Configure the handler to reject connections by setting max_connections to 1 - handler.connection_manager.max_connections = 1; - handler.connection_manager.min_connections = 1; - - let remote_addr = ([127, 0, 0, 1], 10002).into(); + let (handler, mut test) = config_handler(addr, None, true); + let mut stream = HandshakeEventStream::new(handler); + let remote_addr = ([127, 0, 0, 1], 10001).into(); + let remote_pub_key = TransportKeypair::new().public().clone(); let test_controller = async { - let pub_key = TransportKeypair::new().public().clone(); test.transport.new_conn(remote_addr).await; - // Put hops_to_live to 0 to avoid forwarding test.transport - .establish_inbound_conn(remote_addr, pub_key, Some(0)) + .establish_inbound_conn(remote_addr, remote_pub_key.clone(), None) .await; - let msg = test.transport.recv_outbound_msg().await?; - tracing::debug!("Received outbound message: {:?}", msg); - assert!( - matches!(msg, NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Response { - msg: ConnectResponse::AcceptedBy { accepted, .. }, - .. 
- })) if !accepted) - ); + + // Reject the connection + let sender_key = TransportKeypair::new().public().clone(); + let acceptor_key = TransportKeypair::new().public().clone(); + let joiner_key = TransportKeypair::new().public().clone(); + let response = NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Response { + id: Transaction::new::(), + sender: PeerKeyLocation { + peer: PeerId::new(addr, sender_key), + location: Some(Location::random()), + }, + target: PeerKeyLocation { + peer: PeerId::new(remote_addr, remote_pub_key), + location: Some(Location::random()), + }, + msg: ConnectResponse::AcceptedBy { + accepted: false, + acceptor: PeerKeyLocation { + peer: PeerId::new(addr, acceptor_key), + location: Some(Location::random()), + }, + joiner: PeerId::new(remote_addr, joiner_key), + }, + })); + + test.transport.inbound_msg(remote_addr, response).await; Ok::<_, anyhow::Error>(()) }; let gw_inbound = async { + // First event: InboundConnection (may be accepted or rejected depending on routing) let event = - tokio::time::timeout(Duration::from_secs(15), handler.wait_for_events()).await??; - match event { - Event::InboundConnectionRejected { peer_id } => { - assert_eq!(peer_id.addr, remote_addr); - Ok(()) - } - other => Err(anyhow!("Unexpected event: {:?}", other)), - } + tokio::time::timeout(Duration::from_secs(15), next_stream_event(&mut stream)) + .await??; + tracing::info!("Received event: {:?}", event); + Ok(()) }; - futures::try_join!(test_controller, gw_inbound)?; Ok(()) } - #[test_log::test(tokio::test)] - async fn test_peer_to_gw_outbound_conn() -> anyhow::Result<()> { - let addr = ([127, 0, 0, 1], 10000).into(); - let (mut handler, mut test) = config_handler(addr, None, false); + #[tokio::test] + async fn test_stream_peer_to_gw_outbound_conn() -> anyhow::Result<()> { + let addr: SocketAddr = ([127, 0, 0, 1], 10001).into(); + let (handler, mut test) = config_handler(addr, None, false); + let mut stream = HandshakeEventStream::new(handler); let joiner_key = TransportKeypair::new(); let pub_key = joiner_key.public().clone(); let id = Transaction::new::(); + let remote_addr: SocketAddr = ([127, 0, 0, 2], 10002).into(); - let remote_addr: SocketAddr = ([127, 0, 0, 1], 10001).into(); let test_controller = async { let open_connection = start_conn(&mut test, remote_addr, pub_key.clone(), id, true).await; @@ -1702,6 +1956,8 @@ mod tests { .new_outbound_conn(remote_addr, open_connection) .await; tracing::debug!("Outbound connection established"); + + // Wait for and respond to StartJoinReq let msg = test.transport.recv_outbound_msg().await?; let msg = match msg { NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Request { @@ -1736,169 +1992,78 @@ mod tests { Ok::<_, anyhow::Error>(()) }; - let peer_inbound = async { + let peer_outbound = async { let event = - tokio::time::timeout(Duration::from_secs(15), handler.wait_for_events()).await??; + tokio::time::timeout(Duration::from_secs(15), next_stream_event(&mut stream)) + .await??; match event { - Event::OutboundGatewayConnectionSuccessful { peer_id, .. } => { + Event::OutboundGatewayConnectionSuccessful { + peer_id, + connection, + .. 
+ } => { assert_eq!(peer_id.addr, remote_addr); assert_eq!(peer_id.pub_key, pub_key); + drop(connection); Ok(()) } other => bail!("Unexpected event: {:?}", other), } }; - futures::try_join!(test_controller, peer_inbound)?; - Ok(()) - } - - #[tokio::test] - async fn test_peer_to_gw_outbound_conn_failed() -> anyhow::Result<()> { - let addr = ([127, 0, 0, 1], 10000).into(); - let (mut handler, mut test) = config_handler(addr, None, false); - - let joiner_key = TransportKeypair::new(); - let pub_key = joiner_key.public().clone(); - let id = Transaction::new::(); - - let test_controller = async { - let open_connection = start_conn(&mut test, addr, pub_key.clone(), id, true).await; - open_connection - .send(Err(TransportError::ConnectionEstablishmentFailure { - cause: "Connection refused".into(), - })) - .map_err(|_| anyhow!("Failed to send connection"))?; - Ok::<_, anyhow::Error>(()) - }; - let peer_inbound = async { - let event = - tokio::time::timeout(Duration::from_secs(15), handler.wait_for_events()).await??; - match event { - Event::OutboundConnectionFailed { peer_id, error } => { - let addr: SocketAddr = ([127, 0, 0, 1], 10000).into(); - assert_eq!(peer_id.addr, addr); - assert_eq!(peer_id.pub_key, pub_key); - assert!(matches!( - error, - HandshakeError::TransportError( - TransportError::ConnectionEstablishmentFailure { .. } - ) - )); - Ok(()) - } - other => bail!("Unexpected event: {:?}", other), - } - }; - futures::try_join!(test_controller, peer_inbound)?; + futures::try_join!(test_controller, peer_outbound)?; Ok(()) } #[tokio::test] - async fn test_gw_to_peer_outbound_conn_forwarded() -> anyhow::Result<()> { - let gw_addr: SocketAddr = ([127, 0, 0, 1], 10000).into(); - let peer_addr: SocketAddr = ([127, 0, 0, 1], 10001).into(); - let joiner_addr: SocketAddr = ([127, 0, 0, 1], 10002).into(); - - let (mut gw_handler, mut gw_test) = config_handler(gw_addr, None, true); - - // the gw only will accept one connection - gw_handler.connection_manager.max_connections = 1; - gw_handler.connection_manager.min_connections = 1; + async fn test_stream_peer_to_peer_outbound_conn_succeeded() -> anyhow::Result<()> { + let addr: SocketAddr = ([127, 0, 0, 1], 10001).into(); + let (handler, mut test) = config_handler(addr, None, false); + let mut stream = HandshakeEventStream::new(handler); let peer_key = TransportKeypair::new(); - let joiner_key = TransportKeypair::new(); - let peer_pub_key = peer_key.public().clone(); - let joiner_pub_key = joiner_key.public().clone(); - - let peer_peer_id = PeerId::new(peer_addr, peer_pub_key.clone()); - - let gw_test_controller = async { - // the connection to the gw with the third-party peer is established first - gw_test.transport.new_conn(peer_addr).await; - gw_test - .transport - .establish_inbound_conn(peer_addr, peer_pub_key.clone(), None) - .await; + let peer_addr = ([127, 0, 0, 2], 10002).into(); - // Wait longer to ensure the peer connection is fully processed - tokio::time::sleep(Duration::from_millis(500)).await; + let tx = Transaction::new::(); - // the joiner attempts to connect to the gw, but since it's out of connections - // it will just be a transient connection - gw_test.transport.new_conn(joiner_addr).await; - gw_test - .transport - .establish_inbound_conn(joiner_addr, joiner_pub_key, None) + let test_controller = async { + let open_connection = + start_conn(&mut test, peer_addr, peer_pub_key.clone(), tx, false).await; + test.transport + .new_outbound_conn(peer_addr, open_connection) .await; - // Give some time for the events to be processed - 
tokio::time::sleep(Duration::from_millis(100)).await; - - // TODO: maybe simulate forwarding back all expected responses Ok::<_, anyhow::Error>(()) }; - let peer_and_gw = async { - let mut third_party = None; - loop { - let event = - tokio::time::timeout(Duration::from_secs(15), gw_handler.wait_for_events()) - .await??; - match event { - Event::InboundConnection { - conn: first_peer_conn, - joiner: third_party_peer, - .. - } => { - assert_eq!(third_party_peer.pub_key, peer_pub_key); - assert_eq!(first_peer_conn.remote_addr(), peer_addr); - third_party = Some(third_party_peer); - gw_handler.connection_manager.add_connection( - Location::from_address(&peer_addr), - peer_peer_id.clone(), - false, - ); - } - Event::TransientForwardTransaction { - target, - forward_to, - msg, - .. - } => { - // transient connection created, and forwarded a request to join to the third-party peer - assert_eq!(target, joiner_addr); - assert_eq!(forward_to.pub_key, peer_pub_key); - assert_eq!(forward_to.addr, peer_peer_id.addr); - assert!(matches!( - &*msg, - NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Request { - msg: ConnectRequest::CheckConnectivity { .. }, - .. - })) - )); - break; - } - other => bail!("Unexpected event: {:?}", other), + let peer_inbound = async { + let event = + tokio::time::timeout(Duration::from_secs(15), next_stream_event(&mut stream)) + .await??; + match event { + Event::OutboundConnectionSuccessful { + peer_id, + connection, + } => { + assert_eq!(peer_id.addr, peer_addr); + assert_eq!(peer_id.pub_key, peer_pub_key); + drop(connection); + Ok(()) } + other => bail!("Unexpected event: {:?}", other), } - - assert!(third_party.is_some()); - Ok(()) }; - let result = futures::try_join!(gw_test_controller, peer_and_gw); - result?; + futures::try_join!(test_controller, peer_inbound)?; Ok(()) } - #[ignore = "fix this test"] #[tokio::test(flavor = "multi_thread", worker_threads = 2)] - async fn test_peer_to_gw_outbound_conn_rejected() -> anyhow::Result<()> { - // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE), None); + async fn test_stream_peer_to_gw_outbound_conn_rejected() -> anyhow::Result<()> { let joiner_addr = ([127, 0, 0, 1], 10001).into(); - let (mut handler, mut test) = config_handler(joiner_addr, None, false); + let (handler, mut test) = config_handler(joiner_addr, None, false); + let mut stream = HandshakeEventStream::new(handler); let gw_key = TransportKeypair::new(); let gw_pub_key = gw_key.public().clone(); @@ -2034,7 +2199,7 @@ mod tests { for conn_num in 3..Ring::DEFAULT_MAX_HOPS_TO_LIVE { let conn_num = conn_num + 2; let event = - tokio::time::timeout(Duration::from_secs(60), handler.wait_for_events()) + tokio::time::timeout(Duration::from_secs(60), next_stream_event(&mut stream)) .await .inspect_err(|_| { tracing::error!(%conn_num, "failed while waiting for events"); @@ -2063,177 +2228,4 @@ mod tests { futures::try_join!(test_controller, peer_inbound)?; Ok(()) } - - #[tokio::test] - async fn test_peer_to_gw_outbound_conn_forwarded() -> anyhow::Result<()> { - // crate::config::set_logger(Some(tracing::level_filters::LevelFilter::TRACE), None); - let joiner_addr = ([127, 0, 0, 1], 10001).into(); - let (mut handler, mut test) = config_handler(joiner_addr, None, false); - - let gw_key = TransportKeypair::new(); - let gw_pub_key = gw_key.public().clone(); - let gw_addr = ([127, 0, 0, 1], 10000).into(); - let gw_peer_id = PeerId::new(gw_addr, gw_pub_key.clone()); - let gw_pkloc = PeerKeyLocation { - location: 
Some(Location::from_address(&gw_peer_id.addr)), - peer: gw_peer_id.clone(), - }; - - let joiner_key = TransportKeypair::new(); - let joiner_pub_key = joiner_key.public().clone(); - let joiner_peer_id = PeerId::new(joiner_addr, joiner_pub_key.clone()); - let joiner_pkloc = PeerKeyLocation { - peer: joiner_peer_id.clone(), - location: Some(Location::from_address(&joiner_peer_id.addr)), - }; - - let tx = Transaction::new::(); - - let test_controller = async { - let open_connection_peer = - start_conn(&mut test, gw_addr, gw_pub_key.clone(), tx, true).await; - test.transport - .new_outbound_conn(gw_addr, open_connection_peer) - .await; - - let msg = test.transport.recv_outbound_msg().await?; - tracing::info!("Received connec request: {:?}", msg); - let NetMessage::V1(NetMessageV1::Connect(ConnectMsg::Request { - id, - msg: ConnectRequest::StartJoinReq { .. }, - .. - })) = msg - else { - panic!("unexpected message"); - }; - assert_eq!(id, tx); - - let initial_join_req = ConnectMsg::Response { - id: tx, - sender: gw_pkloc.clone(), - target: joiner_pkloc.clone(), - msg: ConnectResponse::AcceptedBy { - accepted: true, - acceptor: gw_pkloc.clone(), - joiner: joiner_peer_id.clone(), - }, - }; - test.transport - .inbound_msg( - gw_addr, - NetMessage::V1(NetMessageV1::Connect(initial_join_req)), - ) - .await; - tracing::debug!("Sent initial gw rejected reply"); - Ok::<_, anyhow::Error>(()) - }; - - let peer_inbound = async { - let event = - tokio::time::timeout(Duration::from_secs(15), handler.wait_for_events()).await??; - let _conn = match event { - Event::OutboundGatewayConnectionSuccessful { - peer_id, - connection, - .. - } => { - tracing::info!(%peer_id, "Gateway connection accepted"); - assert_eq!(peer_id.addr, gw_addr); - connection - } - other => bail!("Unexpected event: {:?}", other), - }; - Ok(()) - }; - - futures::try_join!(test_controller, peer_inbound)?; - Ok(()) - } - - #[tokio::test] - async fn test_peer_to_peer_outbound_conn_failed() -> anyhow::Result<()> { - let addr: SocketAddr = ([127, 0, 0, 1], 10001).into(); - let (mut handler, mut test) = config_handler(addr, None, false); - - let peer_key = TransportKeypair::new(); - let peer_pub_key = peer_key.public().clone(); - let peer_addr = ([127, 0, 0, 2], 10002).into(); - - let tx = Transaction::new::(); - - let test_controller = async { - let open_connection = - start_conn(&mut test, peer_addr, peer_pub_key.clone(), tx, false).await; - open_connection - .send(Err(TransportError::ConnectionEstablishmentFailure { - cause: "Connection refused".into(), - })) - .map_err(|_| anyhow!("Failed to send connection"))?; - Ok::<_, anyhow::Error>(()) - }; - - let peer_inbound = async { - let event = - tokio::time::timeout(Duration::from_secs(15), handler.wait_for_events()).await??; - match event { - Event::OutboundConnectionFailed { peer_id, error } => { - assert_eq!(peer_id.addr, peer_addr); - assert_eq!(peer_id.pub_key, peer_pub_key); - assert!(matches!( - error, - HandshakeError::TransportError( - TransportError::ConnectionEstablishmentFailure { .. 
} - ) - )); - Ok(()) - } - other => bail!("Unexpected event: {:?}", other), - } - }; - - futures::try_join!(test_controller, peer_inbound)?; - Ok(()) - } - - #[tokio::test] - async fn test_peer_to_peer_outbound_conn_succeeded() -> anyhow::Result<()> { - let addr: SocketAddr = ([127, 0, 0, 1], 10001).into(); - let (mut handler, mut test) = config_handler(addr, None, false); - - let peer_key = TransportKeypair::new(); - let peer_pub_key = peer_key.public().clone(); - let peer_addr = ([127, 0, 0, 2], 10002).into(); - - let tx = Transaction::new::(); - - let test_controller = async { - let open_connection = - start_conn(&mut test, peer_addr, peer_pub_key.clone(), tx, false).await; - test.transport - .new_outbound_conn(peer_addr, open_connection) - .await; - - Ok::<_, anyhow::Error>(()) - }; - - let peer_inbound = async { - let event = - tokio::time::timeout(Duration::from_secs(15), handler.wait_for_events()).await??; - match event { - Event::OutboundConnectionSuccessful { - peer_id, - connection, - } => { - assert_eq!(peer_id.addr, peer_addr); - assert_eq!(peer_id.pub_key, peer_pub_key); - drop(connection); - Ok(()) - } - other => bail!("Unexpected event: {:?}", other), - } - }; - - futures::try_join!(test_controller, peer_inbound)?; - Ok(()) - } } diff --git a/crates/core/src/node/network_bridge/p2p_protoc.rs b/crates/core/src/node/network_bridge/p2p_protoc.rs index 9de44f96d..617712817 100644 --- a/crates/core/src/node/network_bridge/p2p_protoc.rs +++ b/crates/core/src/node/network_bridge/p2p_protoc.rs @@ -1,13 +1,9 @@ -use super::{ConnectionError, EventLoopNotificationsReceiver, NetworkBridge}; -use crate::contract::{ContractHandlerEvent, WaitingTransaction}; -use crate::message::{NetMessageV1, QueryResult}; -use crate::node::subscribe::SubscribeMsg; -use crate::ring::Location; use dashmap::DashSet; use either::{Either, Left, Right}; use futures::future::BoxFuture; use futures::stream::FuturesUnordered; use futures::FutureExt; +use futures::StreamExt; use std::convert::Infallible; use std::future::Future; use std::net::{IpAddr, SocketAddr}; @@ -23,13 +19,18 @@ use tokio::sync::oneshot::{self}; use tokio::time::timeout; use tracing::Instrument; +use super::{ConnectionError, EventLoopNotificationsReceiver, NetworkBridge}; +use crate::contract::{ContractHandlerEvent, WaitingTransaction}; +use crate::message::{NetMessageV1, QueryResult}; use crate::node::network_bridge::handshake::{ - Event as HandshakeEvent, ForwardInfo, HandshakeError, HandshakeHandler, HanshakeHandlerMsg, - OutboundMessage, + Event as HandshakeEvent, ForwardInfo, HandshakeError, HandshakeEventStream, HandshakeHandler, + HanshakeHandlerMsg, OutboundMessage, }; use crate::node::network_bridge::priority_select; +use crate::node::subscribe::SubscribeMsg; use crate::node::{MessageProcessor, PeerId}; use crate::operations::{connect::ConnectMsg, get::GetMsg, put::PutMsg, update::UpdateMsg}; +use crate::ring::Location; use crate::transport::{ create_connection_handler, PeerConnection, TransportError, TransportKeypair, }; @@ -166,63 +167,118 @@ impl P2pConnManager { #[allow(clippy::too_many_arguments)] #[tracing::instrument(name = "network_event_listener", fields(peer = %self.bridge.op_manager.ring.connection_manager.pub_key), skip_all)] pub async fn run_event_listener( - mut self, + self, op_manager: Arc, - mut client_wait_for_transaction: ContractHandlerChannel, - mut notification_channel: EventLoopNotificationsReceiver, - mut executor_listener: ExecutorToEventLoopChannel, - mut node_controller: Receiver, + 
+        client_wait_for_transaction: ContractHandlerChannel,
+        notification_channel: EventLoopNotificationsReceiver,
+        executor_listener: ExecutorToEventLoopChannel,
+        node_controller: Receiver,
     ) -> anyhow::Result {
+        // Destructure self to avoid partial move issues
+        let P2pConnManager {
+            gateways,
+            bridge,
+            conn_bridge_rx,
+            event_listener,
+            connections,
+            key_pair,
+            listening_ip,
+            listening_port,
+            is_gateway,
+            this_location,
+            check_version,
+            bandwidth_limit,
+            blocked_addresses,
+            message_processor,
+        } = self;
+
         tracing::info!(
-            %self.listening_port,
-            %self.listening_ip,
-            %self.is_gateway,
-            key = %self.key_pair.public(),
+            %listening_port,
+            %listening_ip,
+            %is_gateway,
+            key = %key_pair.public(),
             "Opening network listener - will receive from channel"
         );
         let mut state = EventListenerState::new();
+        // Separate peer_connections to allow independent borrowing by the stream
+        let peer_connections: FuturesUnordered<
+            BoxFuture<'static, Result>,
+        > = FuturesUnordered::new();
+
         let (outbound_conn_handler, inbound_conn_handler) = create_connection_handler::(
-            self.key_pair.clone(),
-            self.listening_ip,
-            self.listening_port,
-            self.is_gateway,
-            self.bandwidth_limit,
+            key_pair.clone(),
+            listening_ip,
+            listening_port,
+            is_gateway,
+            bandwidth_limit,
         )
         .await?;
         // For non-gateway peers, pass the peer_ready flag so it can be set after first handshake
         // For gateways, pass None (they're always ready)
-        let peer_ready = if !self.is_gateway {
-            Some(self.bridge.op_manager.peer_ready.clone())
+        let peer_ready = if !is_gateway {
+            Some(bridge.op_manager.peer_ready.clone())
         } else {
             None
         };
-        let (mut handshake_handler, handshake_handler_msg, outbound_message) =
-            HandshakeHandler::new(
-                inbound_conn_handler,
-                outbound_conn_handler.clone(),
-                self.bridge.op_manager.ring.connection_manager.clone(),
-                self.bridge.op_manager.ring.router.clone(),
-                self.this_location,
-                self.is_gateway,
-                peer_ready,
-            );
-
-        loop {
-            // Use custom priority select combinator for explicit waker control
-            // This fixes waker registration issues that occurred with nested tokio::select!
-            let event = self
-                .wait_for_event(
+        let (handshake_handler, handshake_handler_msg, outbound_message) = HandshakeHandler::new(
+            inbound_conn_handler,
+            outbound_conn_handler.clone(),
+            bridge.op_manager.ring.connection_manager.clone(),
+            bridge.op_manager.ring.router.clone(),
+            this_location,
+            is_gateway,
+            peer_ready,
+        );
+
+        // Create priority select stream ONCE by moving ownership - it stays alive across iterations.
+        // This fixes the lost wakeup race condition (issue #1932).
+        // HandshakeEventStream wraps HandshakeHandler and implements Stream properly.
+        let handshake_stream = HandshakeEventStream::new(handshake_handler);
+        let select_stream = priority_select::ProductionPrioritySelectStream::new(
+            notification_channel.notifications_receiver,
+            notification_channel.op_execution_receiver,
+            conn_bridge_rx,
+            handshake_stream,
+            node_controller,
+            client_wait_for_transaction,
+            executor_listener,
+            peer_connections,
+        );
+
+        // Pin the stream on the stack
+        tokio::pin!(select_stream);
+
+        // Reconstruct a P2pConnManager-like structure for use in the loop
+        // We can't use the original self because we moved conn_bridge_rx
+        let mut ctx = P2pConnManager {
+            gateways,
+            bridge,
+            conn_bridge_rx: tokio::sync::mpsc::channel(1).1, // Dummy, won't be used
+            event_listener,
+            connections,
+            key_pair,
+            listening_ip,
+            listening_port,
+            is_gateway,
+            this_location,
+            check_version,
+            bandwidth_limit,
+            blocked_addresses,
+            message_processor,
+        };
+
+        while let Some(result) = select_stream.as_mut().next().await {
+            // Process the result using the existing handler
+            let event = ctx
+                .process_select_result(
+                    result,
                     &mut state,
-                    &mut handshake_handler,
+                    &mut select_stream,
                     &handshake_handler_msg,
-                    &mut notification_channel,
-                    &mut node_controller,
-                    &mut client_wait_for_transaction,
-                    &mut executor_listener,
                 )
                 .await?;
@@ -231,12 +287,17 @@ impl P2pConnManager {
                 EventResult::Event(event) => {
                     match *event {
                         ConnEvent::InboundMessage(msg) => {
-                            self.handle_inbound_message(
+                            tracing::info!(
+                                tx = %msg.id(),
+                                msg_type = %msg,
+                                peer = %ctx.bridge.op_manager.ring.connection_manager.get_peer_key().unwrap(),
+                                "Received inbound message from peer - processing"
+                            );
+                            ctx.handle_inbound_message(
                                 msg,
                                 &outbound_message,
                                 &op_manager,
                                 &mut state,
-                                &executor_listener,
                             )
                             .await?;
                         }
@@ -248,14 +309,68 @@ impl P2pConnManager {
                            let Some(target_peer) = msg.target() else {
                                 let id = *msg.id();
                                 tracing::error!(%id, %msg, "Target peer not set, must be set for connection outbound message");
-                                self.bridge.op_manager.completed(id);
+                                ctx.bridge.op_manager.completed(id);
                                 continue;
                             };
-                            tracing::debug!(%target_peer, %msg, "Sending message to peer");
-                            match self.connections.get(&target_peer.peer) {
+
+                            // Check if message targets self - if so, process locally instead of sending over network
+                            let self_peer_id = ctx
+                                .bridge
+                                .op_manager
+                                .ring
+                                .connection_manager
+                                .get_peer_key()
+                                .unwrap();
+                            if target_peer.peer == self_peer_id {
+                                tracing::error!(
+                                    tx = %msg.id(),
+                                    msg_type = %msg,
+                                    target_peer = %target_peer,
+                                    self_peer = %self_peer_id,
+                                    "BUG: OutboundMessage targets self! This indicates a routing logic error - messages should not reach OutboundMessage handler if they target self"
+                                );
+                                // Convert to InboundMessage and process locally
+                                ctx.handle_inbound_message(
+                                    msg,
+                                    &outbound_message,
+                                    &op_manager,
+                                    &mut state,
+                                )
+                                .await?;
+                                continue;
+                            }
+
+                            tracing::info!(
+                                tx = %msg.id(),
+                                msg_type = %msg,
+                                target_peer = %target_peer,
+                                "Sending outbound message to peer"
+                            );
+                            // IMPORTANT: Use a single get() call to avoid TOCTOU race
+                            // between contains_key() and get(). The connection can be
+                            // removed by another task between those two calls.
+ let peer_connection = ctx.connections.get(&target_peer.peer); + tracing::debug!( + tx = %msg.id(), + self_peer = %ctx.bridge.op_manager.ring.connection_manager.pub_key, + target = %target_peer.peer, + conn_map_size = ctx.connections.len(), + has_connection = peer_connection.is_some(), + "[CONN_TRACK] LOOKUP: Checking for existing connection in HashMap" + ); + match peer_connection { Some(peer_connection) => { - if let Err(e) = peer_connection.send(Left(msg)).await { - tracing::error!("Failed to send message to peer: {}", e); + if let Err(e) = peer_connection.send(Left(msg.clone())).await { + tracing::error!( + tx = %msg.id(), + "Failed to send message to peer: {}", e + ); + } else { + tracing::info!( + tx = %msg.id(), + target_peer = %target_peer, + "Message successfully sent to peer connection" + ); } } None => { @@ -270,7 +385,7 @@ impl P2pConnManager { let (callback, mut result) = tokio::sync::mpsc::channel(10); // Initiate connection to the peer - self.bridge + ctx.bridge .ev_listener_tx .send(Right(NodeEvent::ConnectPeer { peer: target_peer.peer.clone(), @@ -284,14 +399,29 @@ impl P2pConnManager { match timeout(Duration::from_secs(5), result.recv()).await { Ok(Some(Ok(_))) => { // Connection established, try sending again - if let Some(peer_connection) = - self.connections.get(&target_peer.peer) - { + // IMPORTANT: Use single get() call to avoid TOCTOU race + let peer_connection_retry = + ctx.connections.get(&target_peer.peer); + tracing::debug!( + tx = %msg.id(), + self_peer = %ctx.bridge.op_manager.ring.connection_manager.pub_key, + target = %target_peer.peer, + conn_map_size = ctx.connections.len(), + has_connection = peer_connection_retry.is_some(), + "[CONN_TRACK] LOOKUP: Retry after connection established - checking for connection in HashMap" + ); + if let Some(peer_connection) = peer_connection_retry { if let Err(e) = peer_connection.send(Left(msg)).await { tracing::error!("Failed to send message to peer after establishing connection: {}", e); } + } else { + tracing::error!( + tx = %tx, + target = %target_peer.peer, + "Connection established successfully but not found in HashMap - possible race condition" + ); } } Ok(Some(Err(e))) => { @@ -315,31 +445,34 @@ impl P2pConnManager { match reason { ChannelCloseReason::Handshake | ChannelCloseReason::Bridge - | ChannelCloseReason::Controller => { + | ChannelCloseReason::Controller + | ChannelCloseReason::Notification + | ChannelCloseReason::OpExecution => { // All ClosedChannel events are critical - the transport is unable to establish // more connections, rendering this peer useless. Perform cleanup and shutdown. 
tracing::error!( ?reason, - is_gateway = self.bridge.op_manager.ring.is_gateway(), - num_connections = self.connections.len(), + is_gateway = ctx.bridge.op_manager.ring.is_gateway(), + num_connections = ctx.connections.len(), "Critical channel closed - performing cleanup and shutting down" ); // Clean up all active connections let peers_to_cleanup: Vec<_> = - self.connections.keys().cloned().collect(); + ctx.connections.keys().cloned().collect(); for peer in peers_to_cleanup { tracing::debug!(%peer, "Cleaning up active connection due to critical channel closure"); // Clean up ring state - self.bridge + ctx.bridge .op_manager .ring .prune_connection(peer.clone()) .await; // Remove from connection map - self.connections.remove(&peer); + tracing::debug!(self_peer = %ctx.bridge.op_manager.ring.connection_manager.pub_key, %peer, conn_map_size = ctx.connections.len(), "[CONN_TRACK] REMOVE: ClosedChannel cleanup - removing from connections HashMap"); + ctx.connections.remove(&peer); // Notify handshake handler to clean up if let Err(e) = handshake_handler_msg @@ -374,8 +507,8 @@ impl P2pConnManager { } ConnEvent::NodeAction(action) => match action { NodeEvent::DropConnection(peer) => { - tracing::debug!(%peer, "Dropping connection"); - if let Some(conn) = self.connections.remove(&peer) { + tracing::debug!(self_peer = %ctx.bridge.op_manager.ring.connection_manager.pub_key, %peer, conn_map_size = ctx.connections.len(), "[CONN_TRACK] REMOVE: DropConnection event - removing from connections HashMap"); + if let Some(conn) = ctx.connections.remove(&peer) { // TODO: review: this could potentially leave garbage tasks in the background with peer listener timeout( Duration::from_secs(1), @@ -400,7 +533,7 @@ impl P2pConnManager { callback, is_gw, } => { - self.handle_connect_peer( + ctx.handle_connect_peer( peer, Box::new(callback), tx, @@ -417,10 +550,10 @@ impl P2pConnManager { %target, "SendMessage event: sending message to peer via network bridge" ); - self.bridge.send(&target, *msg).await?; + ctx.bridge.send(&target, *msg).await?; } NodeEvent::QueryConnections { callback } => { - let connections = self.connections.keys().cloned().collect(); + let connections = ctx.connections.keys().cloned().collect(); timeout( Duration::from_secs(1), callback.send(QueryResult::Connections(connections)), @@ -467,7 +600,7 @@ impl P2pConnManager { } } - let connections = self.connections.keys().cloned().collect(); + let connections = ctx.connections.keys().cloned().collect(); let debug_info = crate::message::NetworkDebugInfo { application_subscriptions: app_subscriptions, network_subscriptions: network_subs, @@ -516,7 +649,7 @@ impl P2pConnManager { // Always include basic node info, but only include address/location if available response.node_info = Some(NodeInfo { - peer_id: self.key_pair.public().to_string(), + peer_id: ctx.key_pair.public().to_string(), is_gateway: self.is_gateway, location: location.map(|loc| format!("{:.6}", loc.0)), listening_address: addr @@ -527,7 +660,7 @@ impl P2pConnManager { // Collect network information if config.include_network_info { - let connected_peers: Vec<_> = self + let connected_peers: Vec<_> = ctx .connections .keys() .map(|p| (p.to_string(), p.addr.to_string())) @@ -535,7 +668,7 @@ impl P2pConnManager { response.network_info = Some(NetworkInfo { connected_peers, - active_connections: self.connections.len(), + active_connections: ctx.connections.len(), }); } @@ -619,7 +752,7 @@ impl P2pConnManager { let seeding_contracts = op_manager.ring.all_network_subscriptions().len() as u32; 
response.system_metrics = Some(SystemMetrics { - active_connections: self.connections.len() as u32, + active_connections: ctx.connections.len() as u32, seeding_contracts, }); } @@ -628,7 +761,7 @@ impl P2pConnManager { if config.include_detailed_peer_info { use freenet_stdlib::client_api::ConnectedPeerInfo; // Populate detailed peer information from actual connections - for peer in self.connections.keys() { + for peer in ctx.connections.keys() { response.connected_peers_detailed.push(ConnectedPeerInfo { peer_id: peer.to_string(), address: peer.addr.to_string(), @@ -694,43 +827,19 @@ impl P2pConnManager { } } } - Err(anyhow::anyhow!( - "Network event listener exited unexpectedly" - )) + Err(anyhow::anyhow!("Network event stream ended unexpectedly")) } - /// Wait for next event using custom priority select combinator. - /// This implementation uses explicit waker control to fix waker registration issues. - #[allow(clippy::too_many_arguments)] - async fn wait_for_event( + /// Process a SelectResult from the priority select stream + async fn process_select_result( &mut self, + result: priority_select::SelectResult, state: &mut EventListenerState, - handshake_handler: &mut HandshakeHandler, + select_stream: &mut priority_select::ProductionPrioritySelectStream, handshake_handler_msg: &HanshakeHandlerMsg, - notification_channel: &mut EventLoopNotificationsReceiver, - node_controller: &mut Receiver, - client_wait_for_transaction: &mut ContractHandlerChannel, - executor_listener: &mut ExecutorToEventLoopChannel, ) -> anyhow::Result { let peer_id = &self.bridge.op_manager.ring.connection_manager.pub_key; - tracing::debug!( - peer = %peer_id, - "wait_for_event: using custom priority select combinator" - ); - - let result = priority_select::select_priority( - &mut notification_channel.notifications_receiver, - &mut notification_channel.op_execution_receiver, - &mut state.peer_connections, - &mut self.conn_bridge_rx, - handshake_handler, - node_controller, - client_wait_for_transaction, - executor_listener, - ) - .await; - use priority_select::SelectResult; match result { SelectResult::Notification(msg) => { @@ -751,10 +860,9 @@ impl P2pConnManager { SelectResult::PeerConnection(msg) => { tracing::debug!( peer = %peer_id, - num_connections = state.peer_connections.len(), "PrioritySelect: peer_connections READY" ); - self.handle_peer_connection_msg(msg, state, handshake_handler_msg) + self.handle_peer_connection_msg(msg, state, select_stream, handshake_handler_msg) .await } SelectResult::ConnBridge(msg) => { @@ -771,8 +879,13 @@ impl P2pConnManager { ); match result { Ok(event) => { - self.handle_handshake_action(event, state, handshake_handler_msg) - .await?; + self.handle_handshake_action( + event, + state, + select_stream, + handshake_handler_msg, + ) + .await?; Ok(EventResult::Continue) } Err(handshake_error) => { @@ -813,7 +926,6 @@ impl P2pConnManager { outbound_message: &OutboundMessage, op_manager: &Arc, state: &mut EventListenerState, - executor_listener: &ExecutorToEventLoopChannel, ) -> anyhow::Result<()> { match msg { NetMessage::V1(NetMessageV1::Aborted(tx)) => { @@ -824,8 +936,7 @@ impl P2pConnManager { // Forward message to transient joiner outbound_message.send_to(*addr, msg).await?; } else { - self.process_message(msg, op_manager, executor_listener, state) - .await; + self.process_message(msg, op_manager, None, state).await; } } } @@ -836,13 +947,23 @@ impl P2pConnManager { &self, msg: NetMessage, op_manager: &Arc, - executor_listener: &ExecutorToEventLoopChannel, + 
executor_callback_opt: Option>, state: &mut EventListenerState, ) { - let executor_callback = state - .pending_from_executor - .remove(msg.id()) - .then(|| executor_listener.callback()); + tracing::info!( + tx = %msg.id(), + tx_type = ?msg.id().transaction_type(), + msg_type = %msg, + peer = %op_manager.ring.connection_manager.get_peer_key().unwrap(), + "process_message called - processing network message" + ); + + // Only use the callback if this message was initiated by the executor + let executor_callback_opt = if state.pending_from_executor.remove(msg.id()) { + executor_callback_opt + } else { + None + }; let span = tracing::info_span!( "process_network_message", @@ -863,7 +984,7 @@ impl P2pConnManager { op_manager.clone(), self.bridge.clone(), self.event_listener.trait_clone(), - executor_callback, + executor_callback_opt, self.message_processor.clone(), pending_op_result, ) @@ -923,6 +1044,7 @@ impl P2pConnManager { &mut self, event: HandshakeEvent, state: &mut EventListenerState, + select_stream: &mut priority_select::ProductionPrioritySelectStream, _handshake_handler_msg: &HanshakeHandlerMsg, // Parameter added ) -> anyhow::Result<()> { match event { @@ -943,8 +1065,18 @@ impl P2pConnManager { return Ok(()); } } - let (tx, rx) = mpsc::channel(1); - self.connections.insert(joiner.clone(), tx); + // Only insert if connection doesn't already exist to avoid dropping existing channel + if !self.connections.contains_key(&joiner) { + let (tx, rx) = mpsc::channel(1); + tracing::debug!(self_peer = %self.bridge.op_manager.ring.connection_manager.pub_key, %joiner, %id, conn_map_size = self.connections.len(), "[CONN_TRACK] INSERT: InboundConnection - adding to connections HashMap"); + self.connections.insert(joiner.clone(), tx); + let task = peer_connection_listener(rx, conn).boxed(); + select_stream.push_peer_connection(task); + } else { + tracing::debug!(self_peer = %self.bridge.op_manager.ring.connection_manager.pub_key, %joiner, %id, conn_map_size = self.connections.len(), "[CONN_TRACK] SKIP INSERT: InboundConnection - connection already exists in HashMap, dropping new connection"); + // Connection already exists - drop the new connection object but continue processing the operation + // The conn will be dropped here which closes the duplicate connection attempt + } // IMPORTANT: Normally we do NOT add connection to ring here! 
// Connection should only be added after StartJoinReq is accepted @@ -978,8 +1110,6 @@ impl P2pConnManager { .push(id, crate::operations::OpEnum::Connect(op)) .await?; } - let task = peer_connection_listener(rx, conn).boxed(); - state.peer_connections.push(task); if let Some(ForwardInfo { target: forward_to, @@ -1012,7 +1142,7 @@ impl P2pConnManager { peer_id, connection, } => { - self.handle_successful_connection(peer_id, connection, state, None) + self.handle_successful_connection(peer_id, connection, state, select_stream, None) .await?; } HandshakeEvent::OutboundGatewayConnectionSuccessful { @@ -1024,6 +1154,7 @@ impl P2pConnManager { peer_id, connection, state, + select_stream, Some(remaining_checks), ) .await?; @@ -1090,6 +1221,7 @@ impl P2pConnManager { peer_id: PeerId, connection: PeerConnection, state: &mut EventListenerState, + select_stream: &mut priority_select::ProductionPrioritySelectStream, remaining_checks: Option, ) -> anyhow::Result<()> { if let Some(mut cb) = state.awaiting_connection.remove(&peer_id.addr) { @@ -1119,10 +1251,17 @@ impl P2pConnManager { } else { tracing::warn!(%peer_id, "No callback for connection established"); } - let (tx, rx) = mpsc::channel(10); - self.connections.insert(peer_id.clone(), tx); - let task = peer_connection_listener(rx, connection).boxed(); - state.peer_connections.push(task); + + // Only insert if connection doesn't already exist to avoid dropping existing channel + if !self.connections.contains_key(&peer_id) { + let (tx, rx) = mpsc::channel(10); + tracing::debug!(self_peer = %self.bridge.op_manager.ring.connection_manager.pub_key, %peer_id, conn_map_size = self.connections.len(), "[CONN_TRACK] INSERT: OutboundConnectionSuccessful - adding to connections HashMap"); + self.connections.insert(peer_id.clone(), tx); + let task = peer_connection_listener(rx, connection).boxed(); + select_stream.push_peer_connection(task); + } else { + tracing::debug!(self_peer = %self.bridge.op_manager.ring.connection_manager.pub_key, %peer_id, conn_map_size = self.connections.len(), "[CONN_TRACK] SKIP INSERT: OutboundConnectionSuccessful - connection already exists in HashMap"); + } Ok(()) } @@ -1130,6 +1269,7 @@ impl P2pConnManager { &mut self, msg: Option>, state: &mut EventListenerState, + select_stream: &mut priority_select::ProductionPrioritySelectStream, handshake_handler_msg: &HanshakeHandlerMsg, ) -> anyhow::Result { match msg { @@ -1167,7 +1307,7 @@ impl P2pConnManager { } let task = peer_connection_listener(peer_conn.rx, peer_conn.conn).boxed(); - state.peer_connections.push(task); + select_stream.push_peer_connection(task); Ok(EventResult::Event( ConnEvent::InboundMessage(peer_conn.msg).into(), )) @@ -1179,7 +1319,7 @@ impl P2pConnManager { .keys() .find_map(|k| (k.addr == socket_addr).then(|| k.clone())) { - tracing::debug!(%peer, "Dropping connection"); + tracing::debug!(self_peer = %self.bridge.op_manager.ring.connection_manager.pub_key, %peer, socket_addr = %socket_addr, conn_map_size = self.connections.len(), "[CONN_TRACK] REMOVE: TransportError::ConnectionClosed - removing from connections HashMap"); self.bridge .op_manager .ring @@ -1201,6 +1341,38 @@ impl P2pConnManager { fn handle_notification_msg(&self, msg: Option>) -> EventResult { match msg { Some(Left(msg)) => { + // Check if message has a target peer - if so, route as outbound, otherwise process locally + if let Some(target) = msg.target() { + let self_peer = self + .bridge + .op_manager + .ring + .connection_manager + .get_peer_key() + .unwrap(); + + tracing::debug!( + tx = 
%msg.id(), + msg_type = %msg, + target_peer = %target, + self_peer = %self_peer, + target_equals_self = (target.peer == self_peer), + "[ROUTING] handle_notification_msg: Checking if message targets self" + ); + + if target.peer != self_peer { + // Message targets another peer - send as outbound + tracing::info!( + tx = %msg.id(), + msg_type = %msg, + target_peer = %target, + "handle_notification_msg: Message has target peer, routing as OutboundMessage" + ); + return EventResult::Event(ConnEvent::OutboundMessage(msg).into()); + } + } + + // Message targets self or has no target - process locally tracing::debug!( tx = %msg.id(), msg_type = %msg, @@ -1212,7 +1384,9 @@ impl P2pConnManager { tracing::debug!("handle_notification_msg: Received NodeEvent notification"); EventResult::Event(ConnEvent::NodeAction(action).into()) } - None => EventResult::Continue, + None => EventResult::Event( + ConnEvent::ClosedChannel(ChannelCloseReason::Notification).into(), + ), } } @@ -1226,13 +1400,17 @@ impl P2pConnManager { state.pending_op_results.insert(*msg.id(), callback); EventResult::Event(ConnEvent::InboundMessage(msg).into()) } - _ => EventResult::Continue, + None => { + EventResult::Event(ConnEvent::ClosedChannel(ChannelCloseReason::OpExecution).into()) + } } } fn handle_bridge_msg(&self, msg: Option) -> EventResult { match msg { - Some(Left((_, msg))) => EventResult::Event(ConnEvent::OutboundMessage(*msg).into()), + Some(Left((_target, msg))) => { + EventResult::Event(ConnEvent::OutboundMessage(*msg).into()) + } Some(Right(action)) => EventResult::Event(ConnEvent::NodeAction(action).into()), None => EventResult::Event(ConnEvent::ClosedChannel(ChannelCloseReason::Bridge).into()), } @@ -1342,8 +1520,7 @@ impl ConnectResultSender for mpsc::Sender), ()>> { } struct EventListenerState { - peer_connections: - FuturesUnordered>>, + // Note: peer_connections has been moved out to allow separate borrowing by the stream pending_from_executor: HashSet, // FIXME: we are potentially leaving trash here when transacrions are completed tx_to_client: HashMap>, @@ -1356,7 +1533,6 @@ struct EventListenerState { impl EventListenerState { fn new() -> Self { Self { - peer_connections: FuturesUnordered::new(), pending_from_executor: HashSet::new(), tx_to_client: HashMap::new(), client_waiting_transaction: Vec::new(), @@ -1388,6 +1564,10 @@ pub(super) enum ChannelCloseReason { Bridge, /// Node controller channel closed - critical, must shutdown gracefully Controller, + /// Notification channel closed - critical, must shutdown gracefully + Notification, + /// Op execution channel closed - critical, must shutdown gracefully + OpExecution, } #[allow(dead_code)] @@ -1398,6 +1578,7 @@ enum ProtocolStatus { Failed, } +#[derive(Debug)] pub(super) struct PeerConnectionInbound { pub conn: PeerConnection, /// Receiver for inbound messages for the peer connection diff --git a/crates/core/src/node/network_bridge/priority_select.rs b/crates/core/src/node/network_bridge/priority_select.rs index 9cd32d508..99a84fb48 100644 --- a/crates/core/src/node/network_bridge/priority_select.rs +++ b/crates/core/src/node/network_bridge/priority_select.rs @@ -2,9 +2,7 @@ //! This avoids waker registration issues that can occur with nested tokio::select! macros. 
use either::Either; -use futures::future::BoxFuture; -use futures::stream::{FuturesUnordered, Stream}; -use pin_project::pin_project; +use futures::{future::BoxFuture, stream::FuturesUnordered, Stream}; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -17,13 +15,14 @@ use crate::contract::{ }; use crate::dev_tool::{PeerId, Transaction}; use crate::message::{NetMessage, NodeEvent}; -use crate::node::network_bridge::handshake::{HandshakeError, HandshakeHandler}; +use crate::node::network_bridge::handshake::HandshakeError; use crate::transport::TransportError; // P2pBridgeEvent type alias for the event bridge channel pub type P2pBridgeEvent = Either<(PeerId, Box), NodeEvent>; #[allow(clippy::large_enum_variant)] +#[derive(Debug)] pub(super) enum SelectResult { Notification(Option>), OpExecution(Option<(tokio::sync::mpsc::Sender, NetMessage)>), @@ -43,204 +42,2173 @@ pub(super) enum SelectResult { ExecutorTransaction(Result), } -/// A future that polls multiple futures with explicit priority order and waker control. -/// Uses pinned BoxFutures that are created once and reused across polls to maintain -/// waker registration and future state (including handshake state machine). -#[pin_project] -pub(super) struct PrioritySelectFuture<'a> { - #[pin] - notification_fut: BoxFuture<'a, Option>>, - #[pin] - op_execution_fut: BoxFuture<'a, Option<(tokio::sync::mpsc::Sender, NetMessage)>>, - #[pin] - peer_connections: - &'a mut FuturesUnordered>>, - #[pin] - conn_bridge_fut: BoxFuture<'a, Option>, - #[pin] - handshake_fut: - BoxFuture<'a, Result>, - #[pin] - node_controller_fut: BoxFuture<'a, Option>, - #[pin] - client_transaction_fut: BoxFuture< - 'a, - Result< +/// Trait for types that can relay client transaction results +pub(super) trait ClientTransactionRelay: Send + Unpin { + fn relay_transaction_result_to_client( + &mut self, + ) -> impl Future< + Output = Result< ( crate::client_events::ClientId, crate::contract::WaitingTransaction, ), anyhow::Error, >, - >, - #[pin] - executor_transaction_fut: BoxFuture<'a, Result>, - peer_connections_empty: bool, + > + Send; +} + +/// Trait for types that can receive transactions from executor +pub(super) trait ExecutorTransactionReceiver: Send + Unpin { + fn transaction_from_executor( + &mut self, + ) -> impl Future> + Send; +} + +impl ClientTransactionRelay for ContractHandlerChannel { + fn relay_transaction_result_to_client( + &mut self, + ) -> impl Future< + Output = Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + >, + > + Send { + self.relay_transaction_result_to_client() + } +} + +impl ExecutorTransactionReceiver for ExecutorToEventLoopChannel { + fn transaction_from_executor( + &mut self, + ) -> impl Future> + Send { + self.transaction_from_executor() + } } -impl<'a> PrioritySelectFuture<'a> { +/// Type alias for the production PrioritySelectStream with concrete types +pub(super) type ProductionPrioritySelectStream = PrioritySelectStream< + super::handshake::HandshakeEventStream, + ContractHandlerChannel, + ExecutorToEventLoopChannel, +>; + +/// Generic stream-based priority select that owns simple Receivers as streams +/// and holds references to complex event sources. +/// This fixes the lost wakeup race condition (issue #1932) by keeping the stream +/// alive across loop iterations, maintaining waker registration. 
+pub(super) struct PrioritySelectStream +where + H: Stream> + Unpin, + C: ClientTransactionRelay, + E: ExecutorTransactionReceiver, +{ + // Streams created from owned receivers + notification: tokio_stream::wrappers::ReceiverStream>, + op_execution: + tokio_stream::wrappers::ReceiverStream<(tokio::sync::mpsc::Sender, NetMessage)>, + conn_bridge: tokio_stream::wrappers::ReceiverStream, + node_controller: tokio_stream::wrappers::ReceiverStream, + + // FuturesUnordered already implements Stream (owned) + peer_connections: + FuturesUnordered>>, + + // HandshakeHandler now implements Stream directly - maintains state across polls + // Generic to allow testing with mocks + handshake_handler: H, + + // These two are owned and we create futures from them that poll their internal state + // Generic to allow testing with mocks + client_wait_for_transaction: C, + executor_listener: E, + + // Track which channels have been reported as closed (to avoid infinite loop of closure notifications) + notification_closed: bool, + op_execution_closed: bool, + conn_bridge_closed: bool, + node_controller_closed: bool, +} + +impl PrioritySelectStream +where + H: Stream> + Unpin, + C: ClientTransactionRelay, + E: ExecutorTransactionReceiver, +{ #[allow(clippy::too_many_arguments)] pub fn new( - notification_fut: BoxFuture<'a, Option>>, - op_execution_fut: BoxFuture< - 'a, - Option<(tokio::sync::mpsc::Sender, NetMessage)>, - >, - peer_connections: &'a mut FuturesUnordered< + notification_rx: Receiver>, + op_execution_rx: Receiver<(tokio::sync::mpsc::Sender, NetMessage)>, + conn_bridge_rx: Receiver, + handshake_handler: H, + node_controller: Receiver, + client_wait_for_transaction: C, + executor_listener: E, + peer_connections: FuturesUnordered< BoxFuture<'static, Result>, >, - conn_bridge_fut: BoxFuture<'a, Option>, - handshake_fut: BoxFuture< - 'a, - Result, - >, - node_controller_fut: BoxFuture<'a, Option>, - client_transaction_fut: BoxFuture< - 'a, - Result< - ( - crate::client_events::ClientId, - crate::contract::WaitingTransaction, - ), - anyhow::Error, - >, - >, - executor_transaction_fut: BoxFuture<'a, Result>, ) -> Self { - let peer_connections_empty = peer_connections.is_empty(); + use tokio_stream::wrappers::ReceiverStream; Self { - notification_fut, - op_execution_fut, + notification: ReceiverStream::new(notification_rx), + op_execution: ReceiverStream::new(op_execution_rx), + conn_bridge: ReceiverStream::new(conn_bridge_rx), + node_controller: ReceiverStream::new(node_controller), peer_connections, - conn_bridge_fut, - handshake_fut, - node_controller_fut, - client_transaction_fut, - executor_transaction_fut, - peer_connections_empty, + handshake_handler, + client_wait_for_transaction, + executor_listener, + notification_closed: false, + op_execution_closed: false, + conn_bridge_closed: false, + node_controller_closed: false, } } + + /// Add a new peer connection task to the stream + pub fn push_peer_connection( + &mut self, + task: BoxFuture<'static, Result>, + ) { + self.peer_connections.push(task); + } } -impl<'a> Future for PrioritySelectFuture<'a> { - type Output = SelectResult; +impl Stream for PrioritySelectStream +where + H: Stream> + Unpin, + C: ClientTransactionRelay, + E: ExecutorTransactionReceiver, +{ + type Item = SelectResult; - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let mut this = self.project(); + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + + // Track if any channel closed (to report after checking all 
sources) + let mut first_closed_channel: Option = None; // Priority 1: Notification channel (highest priority) - // This MUST be polled first to ensure operation state machine messages - // are processed before network messages - match this.notification_fut.as_mut().poll(cx) { - Poll::Ready(msg) => { - tracing::trace!("PrioritySelect: notification_rx ready"); - return Poll::Ready(SelectResult::Notification(msg)); + if !this.notification_closed { + match Pin::new(&mut this.notification).poll_next(cx) { + Poll::Ready(Some(msg)) => { + return Poll::Ready(Some(SelectResult::Notification(Some(msg)))) + } + Poll::Ready(None) => { + // Channel closed - record it and mark as closed to avoid re-polling + this.notification_closed = true; + if first_closed_channel.is_none() { + first_closed_channel = Some(SelectResult::Notification(None)); + } + } + Poll::Pending => {} } - Poll::Pending => {} } - // Priority 2: Op execution channel - match this.op_execution_fut.as_mut().poll(cx) { - Poll::Ready(msg) => { - tracing::trace!("PrioritySelect: op_execution_rx ready"); - return Poll::Ready(SelectResult::OpExecution(msg)); + // Priority 2: Op execution + if !this.op_execution_closed { + match Pin::new(&mut this.op_execution).poll_next(cx) { + Poll::Ready(Some(msg)) => { + return Poll::Ready(Some(SelectResult::OpExecution(Some(msg)))) + } + Poll::Ready(None) => { + // Channel closed - record it and mark as closed to avoid re-polling + this.op_execution_closed = true; + if first_closed_channel.is_none() { + first_closed_channel = Some(SelectResult::OpExecution(None)); + } + } + Poll::Pending => {} } - Poll::Pending => {} } // Priority 3: Peer connections (only if not empty) - if !*this.peer_connections_empty { - match Stream::poll_next(this.peer_connections.as_mut(), cx) { - Poll::Ready(msg) => { - tracing::trace!("PrioritySelect: peer_connections ready"); - return Poll::Ready(SelectResult::PeerConnection(msg)); - } + if !this.peer_connections.is_empty() { + match Pin::new(&mut this.peer_connections).poll_next(cx) { + Poll::Ready(msg) => return Poll::Ready(Some(SelectResult::PeerConnection(msg))), Poll::Pending => {} } } // Priority 4: Connection bridge - match this.conn_bridge_fut.as_mut().poll(cx) { - Poll::Ready(msg) => { - tracing::trace!("PrioritySelect: conn_bridge_rx ready"); - return Poll::Ready(SelectResult::ConnBridge(msg)); + if !this.conn_bridge_closed { + match Pin::new(&mut this.conn_bridge).poll_next(cx) { + Poll::Ready(Some(msg)) => { + return Poll::Ready(Some(SelectResult::ConnBridge(Some(msg)))) + } + Poll::Ready(None) => { + // Channel closed - record it and mark as closed to avoid re-polling + this.conn_bridge_closed = true; + if first_closed_channel.is_none() { + first_closed_channel = Some(SelectResult::ConnBridge(None)); + } + } + Poll::Pending => {} } - Poll::Pending => {} } - // Priority 5: Handshake handler (poll wait_for_events as a whole to preserve all logic) - // The handshake future is pinned in the struct and reused across polls, - // preserving the internal state machine of wait_for_events() - match this.handshake_fut.as_mut().poll(cx) { - Poll::Ready(result) => { - tracing::trace!("PrioritySelect: handshake_handler ready"); - return Poll::Ready(SelectResult::Handshake(result)); - } + // Priority 5: Handshake handler (now implements Stream) + // Poll the handshake handler stream - it maintains state across polls + match Pin::new(&mut this.handshake_handler).poll_next(cx) { + Poll::Ready(Some(result)) => return Poll::Ready(Some(SelectResult::Handshake(result))), + 
+            Poll::Ready(None) => {} // Stream ended (shouldn't happen in practice)
             Poll::Pending => {}
         }

-        // Priority 8: Node controller
-        match this.node_controller_fut.as_mut().poll(cx) {
-            Poll::Ready(msg) => {
-                tracing::trace!("PrioritySelect: node_controller ready");
-                return Poll::Ready(SelectResult::NodeController(msg));
+        // Priority 6: Node controller
+        if !this.node_controller_closed {
+            match Pin::new(&mut this.node_controller).poll_next(cx) {
+                Poll::Ready(Some(msg)) => {
+                    return Poll::Ready(Some(SelectResult::NodeController(Some(msg))))
+                }
+                Poll::Ready(None) => {
+                    // Channel closed - record it and mark as closed to avoid re-polling
+                    this.node_controller_closed = true;
+                    if first_closed_channel.is_none() {
+                        first_closed_channel = Some(SelectResult::NodeController(None));
+                    }
+                }
+                Poll::Pending => {}
             }
-            Poll::Pending => {}
         }

-        // Priority 9: Client transaction waiting
-        match this.client_transaction_fut.as_mut().poll(cx) {
-            Poll::Ready(event_id) => {
-                tracing::trace!("PrioritySelect: client_wait_for_transaction ready");
-                return Poll::Ready(SelectResult::ClientTransaction(event_id));
+        // Priority 7: Client transaction
+        let client_fut = this
+            .client_wait_for_transaction
+            .relay_transaction_result_to_client();
+        tokio::pin!(client_fut);
+        match client_fut.poll(cx) {
+            Poll::Ready(result) => {
+                return Poll::Ready(Some(SelectResult::ClientTransaction(result)))
             }
             Poll::Pending => {}
         }

-        // Priority 10: Executor transaction
-        match this.executor_transaction_fut.as_mut().poll(cx) {
-            Poll::Ready(id) => {
-                tracing::trace!("PrioritySelect: executor_listener ready");
-                return Poll::Ready(SelectResult::ExecutorTransaction(id));
+        // Priority 8: Executor transaction
+        let executor_fut = this.executor_listener.transaction_from_executor();
+        tokio::pin!(executor_fut);
+        match executor_fut.poll(cx) {
+            Poll::Ready(result) => {
+                return Poll::Ready(Some(SelectResult::ExecutorTransaction(result)))
            }
            Poll::Pending => {}
        }

-        // All futures returned Pending - wakers are now registered for all of them
-        // The key difference from the broken implementation: these are the SAME futures
-        // being polled repeatedly, so their wakers persist and internal state is preserved
-        tracing::trace!("PrioritySelect: all pending");
+        // If a channel closed and nothing else is ready, report the closure
+        if let Some(closed) = first_closed_channel {
+            return Poll::Ready(Some(closed));
+        }
+
+        // All pending
        Poll::Pending
    }
}

-#[allow(clippy::too_many_arguments)]
-pub(super) async fn select_priority<'a>(
-    notification_rx: &'a mut Receiver>,
-    op_execution_rx: &'a mut Receiver<(tokio::sync::mpsc::Sender, NetMessage)>,
-    peer_connections: &'a mut FuturesUnordered<
-        BoxFuture<'static, Result>,
-    >,
-    conn_bridge_rx: &'a mut Receiver,
-    handshake_handler: &'a mut HandshakeHandler,
-    node_controller: &'a mut Receiver,
-    client_wait_for_transaction: &'a mut ContractHandlerChannel,
-    executor_listener: &'a mut ExecutorToEventLoopChannel,
-) -> SelectResult {
-    // Create boxed futures ONCE - they will be pinned and reused across polls.
-    // This is critical: the futures must persist across multiple poll() calls to:
-    // 1. Maintain waker registration (so the runtime can wake the task)
-    // 2.
Preserve internal state (especially the handshake state machine) - PrioritySelectFuture::new( - Box::pin(notification_rx.recv()), - Box::pin(op_execution_rx.recv()), - peer_connections, - Box::pin(conn_bridge_rx.recv()), - Box::pin(handshake_handler.wait_for_events()), - Box::pin(node_controller.recv()), - Box::pin(client_wait_for_transaction.relay_transaction_result_to_client()), - Box::pin(executor_listener.transaction_from_executor()), - ) - .await +#[cfg(test)] +mod tests { + use super::*; + use futures::stream::StreamExt; + use tokio::sync::mpsc; + use tokio::time::{sleep, timeout, Duration}; + + /// Mock HandshakeStream for testing that pends forever + struct MockHandshakeStream; + + impl Stream for MockHandshakeStream { + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Poll::Pending + } + } + + /// Create a mock HandshakeStream for testing + fn create_mock_handshake_stream() -> MockHandshakeStream { + MockHandshakeStream + } + + /// Test PrioritySelectStream with notification arriving after initial poll + #[tokio::test] + #[test_log::test] + async fn test_priority_select_future_wakeup() { + struct MockClient; + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + struct MockExecutor; + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + let (notif_tx, notif_rx) = mpsc::channel(10); + let (_op_tx, op_rx) = mpsc::channel(10); + let peers = FuturesUnordered::new(); + let (_bridge_tx, bridge_rx) = mpsc::channel(10); + let (_node_tx, node_rx) = mpsc::channel(10); + + // Spawn task that sends notification after delay + let notif_tx_clone = notif_tx.clone(); + tokio::spawn(async move { + sleep(Duration::from_millis(50)).await; + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + notif_tx_clone.send(Either::Left(test_msg)).await.unwrap(); + }); + + // Create stream - should be pending initially, then wake up when message arrives + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient, + MockExecutor, + peers, + ); + tokio::pin!(stream); + + // Should complete when message arrives (notification has priority over handshake) + let result = timeout(Duration::from_millis(200), stream.next()).await; + + assert!( + result.is_ok(), + "Select stream should wake up when notification arrives" + ); + + let select_result = result.unwrap().expect("Stream should yield value"); + match select_result { + SelectResult::Notification(Some(_)) => {} + SelectResult::Notification(None) => panic!("Got Notification(None)"), + SelectResult::OpExecution(_) => panic!("Got OpExecution"), + SelectResult::PeerConnection(_) => panic!("Got PeerConnection"), + SelectResult::ConnBridge(_) => panic!("Got ConnBridge"), + SelectResult::Handshake(_) => panic!("Got Handshake"), + SelectResult::NodeController(_) => panic!("Got NodeController"), + SelectResult::ClientTransaction(_) => panic!("Got ClientTransaction"), + SelectResult::ExecutorTransaction(_) => panic!("Got ExecutorTransaction"), + } + } + + /// Test that notification has 
priority over other channels in PrioritySelectStream + #[tokio::test] + #[test_log::test] + async fn test_priority_select_future_priority_ordering() { + struct MockClient; + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + struct MockExecutor; + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + let (notif_tx, notif_rx) = mpsc::channel(10); + let (op_tx, op_rx) = mpsc::channel(10); + let peers = FuturesUnordered::new(); + let (bridge_tx, bridge_rx) = mpsc::channel(10); + let (_, node_rx) = mpsc::channel(10); + + // Send to multiple channels - notification should be received first + let (callback_tx, _) = mpsc::channel(1); + let dummy_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + op_tx.send((callback_tx, dummy_msg.clone())).await.unwrap(); + bridge_tx + .send(Either::Right(NodeEvent::Disconnect { cause: None })) + .await + .unwrap(); + + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + notif_tx.send(Either::Left(test_msg)).await.unwrap(); + + // Create and poll the stream + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient, + MockExecutor, + peers, + ); + tokio::pin!(stream); + + let result = timeout(Duration::from_millis(100), stream.next()).await; + assert!(result.is_ok()); + + match result.unwrap().expect("Stream should yield value") { + SelectResult::Notification(_) => {} + _ => panic!("Notification should be received first due to priority"), + } + } + + /// Test concurrent messages - simpler version that sends all messages first + #[tokio::test] + #[test_log::test] + async fn test_priority_select_future_concurrent_messages() { + struct MockClient; + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + struct MockExecutor; + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + let (notif_tx, notif_rx) = mpsc::channel(100); + let peers = FuturesUnordered::new(); + + // Send all 15 messages + for _ in 0..15 { + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + notif_tx.send(Either::Left(test_msg)).await.unwrap(); + } + + // Receive first message + let (_, op_rx) = mpsc::channel(10); + let (_, bridge_rx) = mpsc::channel(10); + let (_, node_rx) = mpsc::channel(10); + + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient, + MockExecutor, + peers, + ); + tokio::pin!(stream); + + let result = timeout(Duration::from_millis(100), stream.next()).await; + assert!(result.is_ok(), "Should receive first message"); + match 
result.unwrap().expect("Stream should yield value") { + SelectResult::Notification(Some(_)) => {} + _ => panic!("Expected notification"), + } + } + + /// Test that messages arrive in buffered channel before receiver polls + #[tokio::test] + #[test_log::test] + async fn test_priority_select_future_buffered_messages() { + struct MockClient; + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + struct MockExecutor; + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + let (notif_tx, notif_rx) = mpsc::channel(10); + let peers = FuturesUnordered::new(); + + // Send message BEFORE creating stream + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + notif_tx.send(Either::Left(test_msg)).await.unwrap(); + + // Create stream - should receive the buffered message immediately + let (_, op_rx) = mpsc::channel(10); + let (_, bridge_rx) = mpsc::channel(10); + let (_, node_rx) = mpsc::channel(10); + + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient, + MockExecutor, + peers, + ); + tokio::pin!(stream); + + let result = timeout(Duration::from_millis(100), stream.next()).await; + assert!( + result.is_ok(), + "Should receive buffered message immediately" + ); + + match result.unwrap().expect("Stream should yield value") { + SelectResult::Notification(Some(_)) => {} + _ => panic!("Expected notification"), + } + } + + /// Test rapid polling of stream with short timeouts + #[tokio::test] + #[test_log::test] + async fn test_priority_select_future_rapid_cancellations() { + use futures::StreamExt; + + struct MockClient; + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + struct MockExecutor; + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + let (notif_tx, notif_rx) = mpsc::channel(100); + let peers = FuturesUnordered::new(); + + // Send 10 messages + for _ in 0..10 { + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + notif_tx.send(Either::Left(test_msg)).await.unwrap(); + } + + let (_, op_rx) = mpsc::channel(10); + let (_, bridge_rx) = mpsc::channel(10); + let (_, node_rx) = mpsc::channel(10); + + // Create stream once - it maintains waker registration across polls + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient, + MockExecutor, + peers, + ); + tokio::pin!(stream); + + // Rapidly poll stream with short timeouts (simulating cancellations) + let mut received = 0; + for _ in 0..30 { + if let Ok(Some(SelectResult::Notification(Some(_)))) = + timeout(Duration::from_millis(5), 
stream.as_mut().next()).await + { + received += 1; + } + + if received >= 10 { + break; + } + } + + assert_eq!( + received, 10, + "Should receive all messages despite rapid cancellations" + ); + } + + /// Test simulating wait_for_event loop behavior - using stream that maintains waker registration + /// This test verifies that PrioritySelectStream properly maintains waker registration across + /// multiple .next().await calls, unlike the old approach that recreated futures each iteration. + /// + /// Enhanced version: sends MULTIPLE messages per channel to verify interleaving and priority. + #[tokio::test] + #[test_log::test] + async fn test_priority_select_event_loop_simulation() { + use futures::StreamExt; + + struct MockClient; + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + struct MockExecutor; + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + // Create channels once (like in wait_for_event) + let (notif_tx, notif_rx) = mpsc::channel::>(10); + let (op_tx, op_rx) = + mpsc::channel::<(tokio::sync::mpsc::Sender, NetMessage)>(10); + let peers = FuturesUnordered::new(); + let (bridge_tx, bridge_rx) = mpsc::channel::(10); + let (node_tx, node_rx) = mpsc::channel::(10); + + // Spawn task that sends MULTIPLE messages to different channels + let notif_tx_clone = notif_tx.clone(); + let op_tx_clone = op_tx.clone(); + let bridge_tx_clone = bridge_tx.clone(); + let node_tx_clone = node_tx.clone(); + tokio::spawn(async move { + sleep(Duration::from_millis(10)).await; + + // Send 3 notifications + for i in 0..3 { + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + tracing::info!("Sending notification {}", i); + notif_tx_clone.send(Either::Left(test_msg)).await.unwrap(); + } + + // Send 2 op execution messages + for i in 0..2 { + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + let (callback_tx, _) = mpsc::channel(1); + tracing::info!("Sending op_execution {}", i); + op_tx_clone.send((callback_tx, test_msg)).await.unwrap(); + } + + // Send 2 bridge events + for i in 0..2 { + tracing::info!("Sending bridge event {}", i); + bridge_tx_clone + .send(Either::Right(NodeEvent::Disconnect { cause: None })) + .await + .unwrap(); + } + + // Send 1 node controller event + tracing::info!("Sending node controller event"); + node_tx_clone + .send(NodeEvent::Disconnect { cause: None }) + .await + .unwrap(); + }); + + // Create stream ONCE - maintains waker registration across iterations + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient, + MockExecutor, + peers, + ); + tokio::pin!(stream); + + let mut received_events = Vec::new(); + + // Simulate event loop: poll stream until we've received all expected messages (3+2+2+1 = 8) + let expected_count = 8; + for iteration in 0..expected_count { + tracing::info!("Event loop iteration {}", iteration); + + // Poll the SAME stream on each iteration - waker registration is maintained + let result = timeout(Duration::from_millis(50), 
stream.as_mut().next()).await; + assert!(result.is_ok(), "Iteration {} should complete", iteration); + + let event = result.unwrap().expect("Stream should yield value"); + match &event { + SelectResult::Notification(_) => received_events.push("notification"), + SelectResult::OpExecution(_) => received_events.push("op_execution"), + SelectResult::ConnBridge(_) => received_events.push("conn_bridge"), + SelectResult::Handshake(_) => received_events.push("handshake"), + SelectResult::NodeController(_) => received_events.push("node_controller"), + _ => received_events.push("other"), + } + + tracing::info!( + "Event loop iteration {} received: {:?}", + iteration, + received_events.last() + ); + } + + // Verify we received all expected messages + assert_eq!( + received_events.len(), + expected_count, + "Should receive all {} messages", + expected_count + ); + + // Count each type + let notif_count = received_events + .iter() + .filter(|&e| *e == "notification") + .count(); + let op_count = received_events + .iter() + .filter(|&e| *e == "op_execution") + .count(); + let bridge_count = received_events + .iter() + .filter(|&e| *e == "conn_bridge") + .count(); + let node_count = received_events + .iter() + .filter(|&e| *e == "node_controller") + .count(); + + tracing::info!("Received counts - notifications: {}, op_execution: {}, bridge: {}, node_controller: {}", + notif_count, op_count, bridge_count, node_count); + + assert_eq!(notif_count, 3, "Should receive 3 notifications"); + assert_eq!(op_count, 2, "Should receive 2 op_execution messages"); + assert_eq!(bridge_count, 2, "Should receive 2 bridge messages"); + assert_eq!(node_count, 1, "Should receive 1 node_controller message"); + + // Verify priority ordering: all notifications should come before any op_execution + // which should come before any bridge events + let first_notif_idx = received_events.iter().position(|e| *e == "notification"); + let last_notif_idx = received_events.iter().rposition(|e| *e == "notification"); + let first_op_idx = received_events.iter().position(|e| *e == "op_execution"); + let last_op_idx = received_events.iter().rposition(|e| *e == "op_execution"); + let first_bridge_idx = received_events.iter().position(|e| *e == "conn_bridge"); + + // All notifications should come first (indices 0, 1, 2) + assert_eq!( + first_notif_idx, + Some(0), + "First notification should be at index 0" + ); + assert_eq!( + last_notif_idx, + Some(2), + "Last notification should be at index 2" + ); + + // All op_executions should come after notifications (indices 3, 4) + assert!( + first_op_idx.unwrap() > last_notif_idx.unwrap(), + "Op execution should come after all notifications" + ); + assert_eq!( + first_op_idx, + Some(3), + "First op_execution should be at index 3" + ); + assert_eq!( + last_op_idx, + Some(4), + "Last op_execution should be at index 4" + ); + + // All bridge events should come after op_executions (indices 5, 6) + assert!( + first_bridge_idx.unwrap() > last_op_idx.unwrap(), + "Bridge events should come after all op_executions" + ); + + tracing::info!( + "✓ All {} messages received in correct priority order: {:?}", + expected_count, + received_events + ); + + // Clean up - drop senders to close channels + drop(notif_tx); + drop(op_tx); + drop(bridge_tx); + drop(node_tx); + // client_tx and executor_tx were moved into MockClient and MockExecutor + } + + /// Stress test: Multiple concurrent tasks sending messages with random delays + /// This test verifies that priority ordering is maintained even under concurrent load + /// 
with unpredictable timing. Each channel has its own task sending messages at random + /// intervals, and we verify all messages are received in perfect priority order. + /// + /// Uses seeded RNG for reproducibility - run with 5 different seeds to ensure robustness. + #[tokio::test] + #[test_log::test] + async fn test_priority_select_concurrent_random_stress() { + test_with_seed(42).await; + test_with_seed(123).await; + test_with_seed(999).await; + test_with_seed(7777).await; + test_with_seed(31415).await; + } + + async fn test_with_seed(seed: u64) { + use rand::rngs::StdRng; + use rand::Rng; + use rand::SeedableRng; + + tracing::info!("=== Stress test with seed {} ===", seed); + + // Define how many messages each sender will send + // Using 2 orders of magnitude more messages to stress test (17 -> 1700) + const NOTIF_COUNT: usize = 500; + const OP_COUNT: usize = 400; + const BRIDGE_COUNT: usize = 300; + const NODE_COUNT: usize = 200; + const CLIENT_COUNT: usize = 200; + const EXECUTOR_COUNT: usize = 100; + const TOTAL_MESSAGES: usize = + NOTIF_COUNT + OP_COUNT + BRIDGE_COUNT + NODE_COUNT + CLIENT_COUNT + EXECUTOR_COUNT; + + // Pre-generate all random delays using seeded RNG + // Most delays are in microseconds (50-500us) with occasional millisecond outliers (1-5ms) + // This keeps the test fast while still testing timing variations + let mut rng = StdRng::seed_from_u64(seed); + let make_delays = |count: usize, rng: &mut StdRng| -> Vec { + (0..count) + .map(|_| { + // 10% chance of millisecond delay (outlier), 90% microsecond delay + if rng.random_range(0..10) == 0 { + rng.random_range(1000..5000) // 1-5ms outliers + } else { + rng.random_range(50..500) // 50-500us typical + } + }) + .collect() + }; + + let notif_delays = make_delays(NOTIF_COUNT, &mut rng); + let op_delays = make_delays(OP_COUNT, &mut rng); + let bridge_delays = make_delays(BRIDGE_COUNT, &mut rng); + let node_delays = make_delays(NODE_COUNT, &mut rng); + let client_delays = make_delays(CLIENT_COUNT, &mut rng); + let executor_delays = make_delays(EXECUTOR_COUNT, &mut rng); + + // Create channels once (like in wait_for_event) + let (notif_tx, notif_rx) = mpsc::channel::>(100); + let (op_tx, op_rx) = + mpsc::channel::<(tokio::sync::mpsc::Sender, NetMessage)>(100); + let peers = FuturesUnordered::new(); + let (bridge_tx, bridge_rx) = mpsc::channel::(100); + let (node_tx, node_rx) = mpsc::channel::(100); + let (client_tx, client_rx) = mpsc::channel::< + Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + >, + >(100); + let (executor_tx, executor_rx) = mpsc::channel::>(100); + + tracing::info!( + "Starting stress test with {} total messages from 6 concurrent tasks", + TOTAL_MESSAGES + ); + + // Spawn separate task for each channel with pre-generated delays + let notif_handle = tokio::spawn(async move { + for (i, &delay_us) in notif_delays.iter().enumerate() { + sleep(Duration::from_micros(delay_us)).await; + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + tracing::debug!( + "Notification task sending message {} after {}us delay", + i, + delay_us + ); + notif_tx.send(Either::Left(test_msg)).await.unwrap(); + } + tracing::info!("Notification task sent all {} messages", NOTIF_COUNT); + NOTIF_COUNT + }); + + let op_handle = tokio::spawn(async move { + for (i, &delay_us) in op_delays.iter().enumerate() { + sleep(Duration::from_micros(delay_us)).await; + let test_msg = 
NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + let (callback_tx, _) = mpsc::channel(1); + tracing::debug!( + "OpExecution task sending message {} after {}us delay", + i, + delay_us + ); + op_tx.send((callback_tx, test_msg)).await.unwrap(); + } + tracing::info!("OpExecution task sent all {} messages", OP_COUNT); + OP_COUNT + }); + + let bridge_handle = tokio::spawn(async move { + for (i, &delay_us) in bridge_delays.iter().enumerate() { + sleep(Duration::from_micros(delay_us)).await; + tracing::debug!( + "Bridge task sending message {} after {}us delay", + i, + delay_us + ); + bridge_tx + .send(Either::Right(NodeEvent::Disconnect { cause: None })) + .await + .unwrap(); + } + tracing::info!("Bridge task sent all {} messages", BRIDGE_COUNT); + BRIDGE_COUNT + }); + + let node_handle = tokio::spawn(async move { + for (i, &delay_us) in node_delays.iter().enumerate() { + sleep(Duration::from_micros(delay_us)).await; + tracing::debug!( + "NodeController task sending message {} after {}us delay", + i, + delay_us + ); + node_tx + .send(NodeEvent::Disconnect { cause: None }) + .await + .unwrap(); + } + tracing::info!("NodeController task sent all {} messages", NODE_COUNT); + NODE_COUNT + }); + + let client_handle = tokio::spawn(async move { + for (i, &delay_us) in client_delays.iter().enumerate() { + sleep(Duration::from_micros(delay_us)).await; + let client_id = crate::client_events::ClientId::next(); + let waiting_tx = + crate::contract::WaitingTransaction::Transaction(Transaction::new::< + crate::operations::put::PutMsg, + >()); + tracing::debug!( + "Client task sending message {} after {}us delay", + i, + delay_us + ); + client_tx.send(Ok((client_id, waiting_tx))).await.unwrap(); + } + tracing::info!("Client task sent all {} messages", CLIENT_COUNT); + CLIENT_COUNT + }); + + let executor_handle = tokio::spawn(async move { + for (i, &delay_us) in executor_delays.iter().enumerate() { + sleep(Duration::from_micros(delay_us)).await; + tracing::debug!( + "Executor task sending message {} after {}us delay", + i, + delay_us + ); + executor_tx + .send(Ok(Transaction::new::())) + .await + .unwrap(); + } + tracing::info!("Executor task sent all {} messages", EXECUTOR_COUNT); + EXECUTOR_COUNT + }); + + // Wait a bit for senders to start sending (shorter delay since we're using microseconds now) + sleep(Duration::from_micros(100)).await; + + // Mock implementations for the stream + + struct MockClient { + rx: mpsc::Receiver< + Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + >, + >, + closed: bool, + } + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + if self.closed { + // Once closed, pend forever instead of returning error repeatedly + futures::future::pending::<()>().await; + unreachable!() + } + match self.rx.recv().await { + Some(result) => result, + None => { + self.closed = true; + Err(anyhow::anyhow!("closed")) + } + } + } + } + + struct MockExecutor { + rx: mpsc::Receiver>, + closed: bool, + } + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + if self.closed { + // Once closed, pend forever instead of returning error repeatedly + futures::future::pending::<()>().await; + unreachable!() + } + match self.rx.recv().await { + Some(result) => 
result, + None => { + self.closed = true; + Err(anyhow::anyhow!("closed")) + } + } + } + } + + // Create stream ONCE - it maintains waker registration and handles channel closures + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient { + rx: client_rx, + closed: false, + }, + MockExecutor { + rx: executor_rx, + closed: false, + }, + peers, + ); + tokio::pin!(stream); + + // Collect all messages from the event loop (run concurrently with senders) + let mut received_events = Vec::new(); + let mut iteration = 0; + + // Continue until we've received all expected messages + use futures::StreamExt; + while received_events.len() < TOTAL_MESSAGES { + // Poll the SAME stream on each iteration - maintains waker registration + let result = timeout(Duration::from_millis(100), stream.as_mut().next()).await; + assert!(result.is_ok(), "Iteration {} timed out", iteration); + + // Stream returns None when there are no more events + let Some(event) = result.unwrap() else { + tracing::debug!("Stream ended (all channels closed)"); + break; + }; + + // Check if this is a real message or a channel close + let (event_name, is_real_message) = match &event { + SelectResult::Notification(msg) => { + if msg.is_some() { + tracing::debug!("Received Notification message"); + ("notification", true) + } else { + tracing::debug!("Notification channel closed"); + ("notification", false) + } + } + SelectResult::OpExecution(msg) => { + if msg.is_some() { + tracing::debug!("Received OpExecution message"); + ("op_execution", true) + } else { + tracing::debug!("OpExecution channel closed"); + ("op_execution", false) + } + } + SelectResult::PeerConnection(msg) => ("peer_connection", msg.is_some()), + SelectResult::ConnBridge(msg) => { + if msg.is_some() { + tracing::debug!("Received ConnBridge message"); + ("conn_bridge", true) + } else { + tracing::debug!("ConnBridge channel closed"); + ("conn_bridge", false) + } + } + SelectResult::Handshake(_) => { + ("handshake", false) // No real messages on this channel in this test + } + SelectResult::NodeController(msg) => { + if msg.is_some() { + tracing::debug!("Received NodeController message"); + ("node_controller", true) + } else { + tracing::debug!("NodeController channel closed"); + ("node_controller", false) + } + } + SelectResult::ClientTransaction(result) => { + if result.is_ok() { + tracing::debug!("Received ClientTransaction message"); + ("client_transaction", true) + } else { + tracing::debug!("ClientTransaction channel closed or error"); + ("client_transaction", false) + } + } + SelectResult::ExecutorTransaction(result) => { + if result.is_ok() { + tracing::debug!("Received ExecutorTransaction message"); + ("executor_transaction", true) + } else { + tracing::debug!("ExecutorTransaction channel closed or error"); + ("executor_transaction", false) + } + } + }; + + // Only count real messages, not channel closures + if is_real_message { + received_events.push(event_name); + // Log every 100 messages to avoid spam with 1700 total messages + if received_events.len() % 100 == 0 { + tracing::info!( + "Received {} of {} real messages", + received_events.len(), + TOTAL_MESSAGES + ); + } + } else { + tracing::debug!( + "Iteration {}: Received channel close from {}", + iteration, + event_name + ); + } + + iteration += 1; + + // Safety check to prevent infinite loop + if iteration > TOTAL_MESSAGES * 3 { + tracing::error!("Receiver loop exceeded maximum iterations. 
Received {} of {} messages after {} iterations", + received_events.len(), TOTAL_MESSAGES, iteration); + panic!("Receiver loop exceeded maximum iterations - possible deadlock"); + } + } + + // Join all sender tasks and get the count of messages they sent + let sent_notif_count = notif_handle.await.unwrap(); + let sent_op_count = op_handle.await.unwrap(); + let sent_bridge_count = bridge_handle.await.unwrap(); + let sent_node_count = node_handle.await.unwrap(); + let sent_client_count = client_handle.await.unwrap(); + let sent_executor_count = executor_handle.await.unwrap(); + + let total_sent = sent_notif_count + + sent_op_count + + sent_bridge_count + + sent_node_count + + sent_client_count + + sent_executor_count; + tracing::info!("All sender tasks completed. Total sent: {}", total_sent); + tracing::info!( + "Receiver completed. Total received: {}", + received_events.len() + ); + + // Verify we received all expected messages + assert_eq!( + received_events.len(), + total_sent, + "Should receive all {} sent messages", + total_sent + ); + assert_eq!( + received_events.len(), + TOTAL_MESSAGES, + "Total received should match expected total" + ); + + // Count each received type + let recv_notif_count = received_events + .iter() + .filter(|&e| *e == "notification") + .count(); + let recv_op_count = received_events + .iter() + .filter(|&e| *e == "op_execution") + .count(); + let recv_bridge_count = received_events + .iter() + .filter(|&e| *e == "conn_bridge") + .count(); + let recv_node_count = received_events + .iter() + .filter(|&e| *e == "node_controller") + .count(); + let recv_client_count = received_events + .iter() + .filter(|&e| *e == "client_transaction") + .count(); + let recv_executor_count = received_events + .iter() + .filter(|&e| *e == "executor_transaction") + .count(); + + tracing::info!("Sent vs Received:"); + tracing::info!( + " notifications: sent={}, received={}", + sent_notif_count, + recv_notif_count + ); + tracing::info!( + " op_execution: sent={}, received={}", + sent_op_count, + recv_op_count + ); + tracing::info!( + " bridge: sent={}, received={}", + sent_bridge_count, + recv_bridge_count + ); + tracing::info!( + " node_controller: sent={}, received={}", + sent_node_count, + recv_node_count + ); + tracing::info!( + " client: sent={}, received={}", + sent_client_count, + recv_client_count + ); + tracing::info!( + " executor: sent={}, received={}", + sent_executor_count, + recv_executor_count + ); + + // Assert sent == received for each type + assert_eq!( + recv_notif_count, sent_notif_count, + "Notification count mismatch" + ); + assert_eq!(recv_op_count, sent_op_count, "OpExecution count mismatch"); + assert_eq!( + recv_bridge_count, sent_bridge_count, + "Bridge count mismatch" + ); + assert_eq!( + recv_node_count, sent_node_count, + "NodeController count mismatch" + ); + assert_eq!( + recv_client_count, sent_client_count, + "Client count mismatch" + ); + assert_eq!( + recv_executor_count, sent_executor_count, + "Executor count mismatch" + ); + + tracing::info!("✓ STRESS TEST PASSED for seed {}!", seed); + tracing::info!( + " All {} messages received correctly from 6 concurrent senders with random delays", + TOTAL_MESSAGES + ); + tracing::info!(" Received events: {:?}", received_events); + tracing::info!(" Priority ordering respected: when multiple messages buffered, highest priority selected first"); + } + + /// Test that verifies waker registration across ALL channels when they're all Pending + /// This is the critical behavior: when a PrioritySelectStream polls all 8 
channels and they + /// all return Pending, it must register wakers for ALL of them, not just some. + #[tokio::test] + #[test_log::test] + async fn test_priority_select_all_pending_waker_registration() { + use futures::StreamExt; + + struct MockClient { + rx: mpsc::Receiver< + Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + >, + >, + closed: bool, + } + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + if self.closed { + // Once closed, pend forever instead of returning error repeatedly + futures::future::pending::<()>().await; + unreachable!() + } + match self.rx.recv().await { + Some(result) => result, + None => { + self.closed = true; + Err(anyhow::anyhow!("closed")) + } + } + } + } + + struct MockExecutor { + rx: mpsc::Receiver>, + closed: bool, + } + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + if self.closed { + // Once closed, pend forever instead of returning error repeatedly + futures::future::pending::<()>().await; + unreachable!() + } + match self.rx.recv().await { + Some(result) => result, + None => { + self.closed = true; + Err(anyhow::anyhow!("closed")) + } + } + } + } + + // Create all 8 channels + let (notif_tx, notif_rx) = mpsc::channel::>(10); + let (op_tx, op_rx) = + mpsc::channel::<(tokio::sync::mpsc::Sender, NetMessage)>(10); + let peers = FuturesUnordered::new(); + let (bridge_tx, bridge_rx) = mpsc::channel::(10); + let (node_tx, node_rx) = mpsc::channel::(10); + let (client_tx, client_rx) = mpsc::channel::< + Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + >, + >(10); + let (executor_tx, executor_rx) = mpsc::channel::>(10); + + // Start with NO messages buffered - this will cause all channels to return Pending on first poll + tracing::info!("Creating PrioritySelectStream with all channels empty"); + + // Spawn a task that will send messages after a delay + // This gives the stream time to poll all channels and register wakers + tokio::spawn(async move { + sleep(Duration::from_millis(10)).await; + tracing::info!("All wakers should now be registered, sending messages"); + + // Send to multiple channels simultaneously (in reverse priority order) + tracing::info!("Sending to executor channel (lowest priority)"); + executor_tx + .send(Ok(Transaction::new::())) + .await + .unwrap(); + + tracing::info!("Sending to client channel"); + let client_id = crate::client_events::ClientId::next(); + let waiting_tx = crate::contract::WaitingTransaction::Transaction(Transaction::new::< + crate::operations::put::PutMsg, + >()); + client_tx.send(Ok((client_id, waiting_tx))).await.unwrap(); + + tracing::info!("Sending to node controller channel"); + node_tx + .send(NodeEvent::Disconnect { cause: None }) + .await + .unwrap(); + + tracing::info!("Sending to bridge channel"); + bridge_tx + .send(Either::Right(NodeEvent::Disconnect { cause: None })) + .await + .unwrap(); + + tracing::info!("Sending to op execution channel (second priority)"); + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + let (callback_tx, _) = mpsc::channel(1); + op_tx.send((callback_tx, test_msg.clone())).await.unwrap(); + + tracing::info!("Sending to notification channel (highest 
priority)"); + notif_tx.send(Either::Left(test_msg)).await.unwrap(); + }); + + // Create the stream - it will poll all channels, find them all Pending, + // and register wakers for all of them + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient { + rx: client_rx, + closed: false, + }, + MockExecutor { + rx: executor_rx, + closed: false, + }, + peers, + ); + tokio::pin!(stream); + + // Poll the stream - it should wake up and return the NOTIFICATION (highest priority) + // despite all other channels also having messages + tracing::info!("PrioritySelectStream started, should poll all channels and go Pending"); + let result = timeout(Duration::from_millis(100), stream.next()).await; + assert!( + result.is_ok(), + "Select should wake up when any message arrives" + ); + + let select_result = result.unwrap().expect("Stream should yield value"); + match select_result { + SelectResult::Notification(_) => { + tracing::info!( + "✓ Correctly received Notification despite 5 other channels having messages" + ); + } + SelectResult::OpExecution(_) => { + panic!("Should prioritize Notification over OpExecution") + } + SelectResult::ConnBridge(_) => panic!("Should prioritize Notification over ConnBridge"), + SelectResult::NodeController(_) => { + panic!("Should prioritize Notification over NodeController") + } + SelectResult::ClientTransaction(_) => { + panic!("Should prioritize Notification over ClientTransaction") + } + SelectResult::ExecutorTransaction(_) => { + panic!("Should prioritize Notification over ExecutorTransaction") + } + _ => panic!("Unexpected result"), + } + } + + /// Test that reproduces the lost wakeup race condition from issue #1932 + /// + /// This test demonstrates the bug where recreating PrioritySelectFuture on every + /// iteration loses waker registration, causing messages to be missed. + /// + /// This test verifies the fix using PrioritySelectStream which maintains waker registration. 
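+    ///
+    /// A minimal sketch of the difference (names are illustrative, not the exact production API):
+    ///
+    /// ```ignore
+    /// // Buggy pattern: a brand-new select future is built on every iteration, so the
+    /// // wakers registered by the previous (dropped) future no longer wake this task.
+    /// loop {
+    ///     let event = priority_select(&mut channels).await; // fresh future each time
+    ///     handle(event);
+    /// }
+    ///
+    /// // Fixed pattern: one long-lived stream is polled repeatedly and keeps its
+    /// // waker registrations alive between iterations.
+    /// let mut stream = PrioritySelectStream::new(channels);
+    /// while let Some(event) = stream.next().await {
+    ///     handle(event);
+    /// }
+    /// ```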
+ #[tokio::test] + #[test_log::test] + async fn test_sparse_messages_reproduce_race() { + tracing::info!( + "=== Testing sparse messages with PrioritySelectStream (verifying fix for #1932) ===" + ); + + // Mock implementations for testing + + struct MockClient; + impl ClientTransactionRelay for MockClient { + async fn relay_transaction_result_to_client( + &mut self, + ) -> Result< + ( + crate::client_events::ClientId, + crate::contract::WaitingTransaction, + ), + anyhow::Error, + > { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + struct MockExecutor; + impl ExecutorTransactionReceiver for MockExecutor { + async fn transaction_from_executor(&mut self) -> anyhow::Result { + sleep(Duration::from_secs(1000)).await; + Err(anyhow::anyhow!("closed")) + } + } + + let (notif_tx, notif_rx) = mpsc::channel::>(10); + let (_, op_rx) = mpsc::channel(1); + let peers = FuturesUnordered::new(); + let (_, bridge_rx) = mpsc::channel(1); + let (_, node_rx) = mpsc::channel(1); + + // Spawn sender that sends 5 messages with 200ms gaps + let sender = tokio::spawn(async move { + for i in 0..5 { + sleep(Duration::from_millis(200)).await; + tracing::info!( + "Sender: Sending message {} at {:?}", + i, + std::time::Instant::now() + ); + let test_msg = NetMessage::V1(crate::message::NetMessageV1::Aborted( + crate::message::Transaction::new::(), + )); + match notif_tx.send(Either::Left(test_msg)).await { + Ok(_) => tracing::info!("Sender: Message {} sent successfully", i), + Err(e) => { + tracing::error!("Sender: Failed to send message {}: {:?}", i, e); + break; + } + } + } + tracing::info!("Sender: Finished sending all messages"); + }); + + // Create the stream ONCE - this is the fix! + let stream = PrioritySelectStream::new( + notif_rx, + op_rx, + bridge_rx, + create_mock_handshake_stream(), + node_rx, + MockClient, + MockExecutor, + peers, + ); + tokio::pin!(stream); + + let mut received = 0; + let mut iteration = 0; + + // Receiver polls the SAME stream repeatedly (the fix - maintains waker registration) + while received < 5 && iteration < 20 { + iteration += 1; + tracing::info!( + "Iteration {}: Polling PrioritySelectStream (reusing same stream)", + iteration + ); + + match timeout(Duration::from_millis(300), stream.as_mut().next()).await { + Ok(Some(SelectResult::Notification(Some(_)))) => { + received += 1; + tracing::info!( + "✅ Iteration {}: Received message {} of 5", + iteration, + received + ); + } + Ok(Some(_)) => { + tracing::debug!("Iteration {}: Got other event", iteration); + } + Ok(None) => { + tracing::error!("Stream ended unexpectedly"); + break; + } + Err(_) => { + tracing::warn!("Iteration {}: Timeout waiting for message", iteration); + } + } + } + + // Wait for sender to finish + sender.await.unwrap(); + tracing::info!("Sender task completed, received {} messages", received); + + assert_eq!( + received, 5, + "❌ FAIL: PrioritySelectStream still lost messages! 
Expected 5 but received {} in {} iterations.\n\
+            The fix should prevent lost wakeups by keeping the stream alive.",
+            received, iteration
+        );
+        tracing::info!("✅ PASS: All 5 messages received without loss using PrioritySelectStream!");
+    }
+
+    /// Test that the stream-based approach doesn't lose messages with sparse arrivals.
+    /// This reproduces the race-condition scenario but with the stream-based fix.
+    #[tokio::test]
+    #[test_log::test]
+    async fn test_stream_no_lost_messages_sparse_arrivals() {
+        use tokio_stream::wrappers::ReceiverStream;
+
+        tracing::info!("=== Testing stream approach doesn't lose messages (sparse arrivals) ===");
+
+        let (tx, rx) = mpsc::channel::<String>(10);
+
+        // Convert receiver to stream
+        let stream = ReceiverStream::new(rx);
+
+        // Simple stream wrapper that yields items
+        struct MessageStream<S> {
+            inner: S,
+        }
+
+        impl<S: Stream + Unpin> Stream for MessageStream<S> {
+            type Item = S::Item;
+
+            fn poll_next(
+                mut self: Pin<&mut Self>,
+                cx: &mut Context<'_>,
+            ) -> Poll<Option<Self::Item>> {
+                Pin::new(&mut self.inner).poll_next(cx)
+            }
+        }
+
+        let mut message_stream = MessageStream { inner: stream };
+
+        // Spawn sender that sends 5 messages with 200ms gaps (sparse arrivals)
+        let sender = tokio::spawn(async move {
+            for i in 0..5 {
+                sleep(Duration::from_millis(200)).await;
+                tracing::info!(
+                    "Sender: Sending message {} at {:?}",
+                    i,
+                    std::time::Instant::now()
+                );
+                tx.send(format!("msg{}", i)).await.unwrap();
+                tracing::info!("Sender: Message {} sent successfully", i);
+            }
+        });
+
+        // Receiver loop: call stream.next().await repeatedly.
+        // The stream should maintain waker registration across iterations.
+        let mut received = 0;
+        for iteration in 1..=20 {
+            tracing::info!("Iteration {}: Calling stream.next().await", iteration);
+
+            let msg = timeout(Duration::from_millis(300), message_stream.next()).await;
+
+            match msg {
+                Ok(Some(msg)) => {
+                    received += 1;
+                    tracing::info!("✓ Received: {} (total: {})", msg, received);
+                }
+                Ok(None) => {
+                    tracing::info!("Stream ended");
+                    break;
+                }
+                Err(_) => {
+                    tracing::info!(
+                        "Timeout on iteration {} (received {} so far)",
+                        iteration,
+                        received
+                    );
+                    if received >= 5 {
+                        break; // All messages received
+                    }
+                }
+            }
+        }
+
+        sender.await.unwrap();
+        tracing::info!("Sender task completed, received {} messages", received);
+
+        assert_eq!(
+            received, 5,
+            "Stream approach should receive ALL messages! Expected 5 but got {}.\n\
+            The stream maintains waker registration across .next().await calls.",
+            received
+        );
+
+        tracing::info!(
+            "✓ SUCCESS: Stream-based approach received all 5 messages with sparse arrivals!"
+        );
+        tracing::info!(
+            "✓ Waker registration was maintained across stream.next().await iterations!"
+        );
+    }
+
+    /// Test that recreating futures on each poll maintains waker registration.
+    /// This tests the hypothesis for "special" types with async methods.
+    #[tokio::test]
+    #[test_log::test]
+    async fn test_recreating_futures_maintains_waker() {
+        tracing::info!("=== Testing that recreating futures on each poll maintains waker ===");
+
+        // Mock "special" type with an async method and internal state
+        struct MockSpecial {
+            counter: std::sync::Arc<std::sync::Mutex<usize>>,
+            rx: tokio::sync::mpsc::Receiver<String>,
+        }
+
+        impl MockSpecial {
+            // Async method that borrows &mut self
+            async fn wait_for_event(&mut self) -> Option<String> {
+                tracing::info!("MockSpecial::wait_for_event called");
+                let msg = self.rx.recv().await?;
+                let mut counter = self.counter.lock().unwrap();
+                *counter += 1;
+                tracing::info!("MockSpecial: received '{}', counter now {}", msg, *counter);
+                Some(msg)
+            }
+        }
+
+        // Stream that owns MockSpecial and recreates futures on each poll
+        struct TestStream {
+            special: MockSpecial,
+        }
+
+        impl Stream for TestStream {
+            type Item = String;
+
+            fn poll_next(
+                mut self: Pin<&mut Self>,
+                cx: &mut Context<'_>,
+            ) -> Poll<Option<Self::Item>> {
+                // KEY: Create fresh future on EVERY poll
+                let fut = self.special.wait_for_event();
+                tokio::pin!(fut);
+
+                match fut.poll(cx) {
+                    Poll::Ready(Some(msg)) => Poll::Ready(Some(msg)),
+                    Poll::Ready(None) => Poll::Ready(None), // Channel closed
+                    Poll::Pending => Poll::Pending,
+                }
+            }
+        }
+
+        let counter = std::sync::Arc::new(std::sync::Mutex::new(0));
+        let (tx, rx) = mpsc::channel::<String>(10);
+
+        let mut test_stream = TestStream {
+            special: MockSpecial {
+                counter: counter.clone(),
+                rx,
+            },
+        };
+
+        // Spawn sender with sparse arrivals (200ms gaps)
+        let sender = tokio::spawn(async move {
+            for i in 0..5 {
+                sleep(Duration::from_millis(200)).await;
+                tracing::info!("Sender: Sending message {}", i);
+                tx.send(format!("msg{}", i)).await.unwrap();
+            }
+        });
+
+        // Receive using stream.next().await in a loop
+        let mut received = 0;
+        for iteration in 1..=20 {
+            tracing::info!("Iteration {}: Calling stream.next().await", iteration);
+
+            let msg = timeout(Duration::from_millis(300), test_stream.next()).await;
+
+            match msg {
+                Ok(Some(msg)) => {
+                    received += 1;
+                    tracing::info!("✓ Received: {} (total: {})", msg, received);
+                }
+                Ok(None) => {
+                    tracing::info!("Stream ended");
+                    break;
+                }
+                Err(_) => {
+                    tracing::info!(
+                        "Timeout on iteration {} (received {} so far)",
+                        iteration,
+                        received
+                    );
+                    if received >= 5 {
+                        break;
+                    }
+                }
+            }
+        }
+
+        sender.await.unwrap();
+
+        assert_eq!(
+            received, 5,
+            "Recreating futures on each poll should STILL receive all messages! Got {}",
+            received
+        );
+
+        let final_counter = *counter.lock().unwrap();
+        assert_eq!(final_counter, 5, "Counter should be 5");
+
+        tracing::info!("✓ SUCCESS: Recreating futures on each poll MAINTAINS waker registration!");
+        tracing::info!(
+            "✓ The stream struct staying alive is what matters, not the individual futures!"
+        );
+    }
+
+    /// Test that nested tokio::select! works correctly with the stream approach.
+    /// This is critical because HandshakeHandler::wait_for_events has a nested select!
+    ///
+    /// This verifies that even when async methods contain nested selects,
+    /// the stream maintains waker registration and doesn't lose messages.
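+    ///
+    /// Note: recreating the inner future on every poll works here because the receivers
+    /// (and any messages buffered in them) are owned by the stream struct, so a message
+    /// that arrives between polls is simply observed on the next poll, which also
+    /// re-registers the task's waker with the current context.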
+ #[tokio::test] + #[test_log::test] + async fn test_recreating_futures_with_nested_select() { + use futures::StreamExt; + + tracing::info!("=== Testing stream with NESTED select (like HandshakeHandler) ==="); + + // Mock type with nested select (simulating HandshakeHandler pattern) + struct MockWithNestedSelect { + rx1: tokio::sync::mpsc::Receiver, + rx2: tokio::sync::mpsc::Receiver, + counter: std::sync::Arc>, + } + + impl MockWithNestedSelect { + // Async method with nested tokio::select! (like wait_for_events) + async fn wait_for_event(&mut self) -> String { + // NESTED SELECT - just like HandshakeHandler::wait_for_events + tokio::select! { + msg1 = self.rx1.recv() => { + if let Some(msg) = msg1 { + let mut counter = self.counter.lock().unwrap(); + *counter += 1; + tracing::info!("Nested select: rx1 received '{}', counter {}", msg, *counter); + format!("rx1:{}", msg) + } else { + "rx1:closed".to_string() + } + } + msg2 = self.rx2.recv() => { + if let Some(msg) = msg2 { + let mut counter = self.counter.lock().unwrap(); + *counter += 1; + tracing::info!("Nested select: rx2 received '{}', counter {}", msg, *counter); + format!("rx2:{}", msg) + } else { + "rx2:closed".to_string() + } + } + } + } + } + + // Stream that creates fresh futures on each poll - just like PrioritySelectStream + struct TestStream { + special: MockWithNestedSelect, + } + + impl Stream for TestStream { + type Item = String; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + // Create fresh future on EVERY poll - this is what PrioritySelectStream does + let fut = self.special.wait_for_event(); + tokio::pin!(fut); + + match fut.poll(cx) { + Poll::Ready(msg) => Poll::Ready(Some(msg)), + Poll::Pending => Poll::Pending, + } + } + } + + let counter = std::sync::Arc::new(std::sync::Mutex::new(0)); + let (tx1, rx1) = mpsc::channel::(10); + let (tx2, rx2) = mpsc::channel::(10); + + // KEY FIX: Send all messages BEFORE starting to receive + // This eliminates the race between sender and receiver + for i in 0..3 { + if i % 2 == 0 { + tracing::info!("Sending to rx1: msg{}", i); + tx1.send(format!("msg{}", i)).await.unwrap(); + } else { + tracing::info!("Sending to rx2: msg{}", i); + tx2.send(format!("msg{}", i)).await.unwrap(); + } + } + tracing::info!("All 3 messages sent, now dropping senders"); + drop(tx1); + drop(tx2); + + // Create the stream ONCE and reuse it - key to maintaining waker registration + let test_stream = TestStream { + special: MockWithNestedSelect { + rx1, + rx2, + counter: counter.clone(), + }, + }; + tokio::pin!(test_stream); + + // Receive all messages + let mut received = Vec::new(); + for iteration in 1..=10 { + tracing::info!("Iteration {}: Calling stream.next().await", iteration); + + let msg = timeout(Duration::from_millis(100), test_stream.as_mut().next()).await; + + match msg { + Ok(Some(msg)) => { + if msg.contains("closed") { + tracing::info!("Channel closed: {}", msg); + // Continue to check if other channel has messages + continue; + } + received.push(msg.clone()); + tracing::info!("✓ Received: {} (total: {})", msg, received.len()); + + if received.len() >= 3 { + break; + } + } + Ok(None) => { + tracing::info!("Stream ended"); + break; + } + Err(_) => { + tracing::info!( + "Timeout on iteration {} (received {} so far)", + iteration, + received.len() + ); + break; + } + } + } + + assert_eq!( + received.len(), + 3, + "Stream with NESTED select should receive all messages! 
Got {} messages: {:?}", + received.len(), + received + ); + + let final_counter = *counter.lock().unwrap(); + assert_eq!(final_counter, 3, "Counter should be 3"); + + tracing::info!( + "✅ SUCCESS: Stream with NESTED select (like HandshakeHandler) maintains waker registration!" + ); + tracing::info!("✅ Received all messages: {:?}", received); + } + + /// Test the critical edge case: messages arrive with very tight timing + /// This simulates what happens in production when messages arrive rapidly + /// while the nested select is processing. + #[tokio::test] + #[test_log::test] + async fn test_nested_select_concurrent_arrivals() { + use futures::StreamExt; + + tracing::info!("=== Testing nested select with rapid concurrent arrivals ==="); + + struct MockWithNestedSelect { + rx1: tokio::sync::mpsc::Receiver, + rx2: tokio::sync::mpsc::Receiver, + } + + impl MockWithNestedSelect { + async fn wait_for_event(&mut self) -> String { + tokio::select! { + msg1 = self.rx1.recv() => { + if let Some(msg) = msg1 { + tracing::info!("Nested select: rx1 received '{}'", msg); + format!("rx1:{}", msg) + } else { + "rx1:closed".to_string() + } + } + msg2 = self.rx2.recv() => { + if let Some(msg) = msg2 { + tracing::info!("Nested select: rx2 received '{}'", msg); + format!("rx2:{}", msg) + } else { + "rx2:closed".to_string() + } + } + } + } + } + + struct TestStream { + special: MockWithNestedSelect, + } + + impl Stream for TestStream { + type Item = String; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let fut = self.special.wait_for_event(); + tokio::pin!(fut); + match fut.poll(cx) { + Poll::Ready(msg) => Poll::Ready(Some(msg)), + Poll::Pending => Poll::Pending, + } + } + } + + let (tx1, rx1) = mpsc::channel::(10); + let (tx2, rx2) = mpsc::channel::(10); + + let test_stream = TestStream { + special: MockWithNestedSelect { rx1, rx2 }, + }; + tokio::pin!(test_stream); + + // STRESS TEST: 1000 messages (100x more than original) + // Spawn a sender that rapidly sends messages alternating between channels + const MESSAGE_COUNT: usize = 1000; + tokio::spawn(async move { + for i in 0..MESSAGE_COUNT { + // Send to alternating channels with minimal delay + if i % 2 == 0 { + if i % 100 == 0 { + tracing::info!("Sending msg{} to rx1 ({} sent)", i, i); + } + tx1.send(format!("msg{}", i)).await.unwrap(); + } else { + if i % 100 == 0 { + tracing::info!("Sending msg{} to rx2 ({} sent)", i, i); + } + tx2.send(format!("msg{}", i)).await.unwrap(); + } + // Tiny delay to allow some interleaving and race conditions + sleep(Duration::from_micros(10)).await; + } + tracing::info!("Sender finished: sent all {} messages", MESSAGE_COUNT); + }); + + // Receive all messages - if wakers are maintained, we should get all 1000 + let mut received = Vec::new(); + for iteration in 0..(MESSAGE_COUNT + 100) { + match timeout(Duration::from_millis(100), test_stream.as_mut().next()).await { + Ok(Some(msg)) => { + if !msg.contains("closed") { + received.push(msg); + if received.len() % 100 == 0 { + tracing::info!( + "Received {} of {} messages", + received.len(), + MESSAGE_COUNT + ); + } + } + if received.len() >= MESSAGE_COUNT { + break; + } + } + Ok(None) => break, + Err(_) => { + tracing::info!( + "Timeout on iteration {} after receiving {} messages", + iteration, + received.len() + ); + break; + } + } + } + + assert_eq!( + received.len(), MESSAGE_COUNT, + "Should receive all {} messages even with rapid arrivals! Got {}. 
First 10: {:?}, Last 10: {:?}", + MESSAGE_COUNT, received.len(), + &received[..received.len().min(10)], + &received[received.len().saturating_sub(10)..] + ); + + tracing::info!("✅ SUCCESS: All {} rapid messages received!", MESSAGE_COUNT); + tracing::info!( + "✅ Nested select with stream maintains waker registration under high concurrent load!" + ); + } } diff --git a/crates/core/src/node/op_state_manager.rs b/crates/core/src/node/op_state_manager.rs index 6ee78b6eb..541a71c27 100644 --- a/crates/core/src/node/op_state_manager.rs +++ b/crates/core/src/node/op_state_manager.rs @@ -229,6 +229,15 @@ impl OpManager { } pub async fn push(&self, id: Transaction, op: OpEnum) -> Result<(), OpError> { + // Check if operation is already completed - don't push back to HashMap + if self.ops.completed.contains(&id) { + tracing::debug!( + tx = %id, + "OpManager: Ignoring push for already completed operation" + ); + return Ok(()); + } + if let Some(tx) = self.ops.under_progress.remove(&id) { if tx.timed_out() { self.ops.completed.insert(tx); diff --git a/crates/core/src/operations/put.rs b/crates/core/src/operations/put.rs index 9bb6ec748..16ad8e538 100644 --- a/crates/core/src/operations/put.rs +++ b/crates/core/src/operations/put.rs @@ -163,6 +163,7 @@ impl Operation for PutOp { match input { PutMsg::RequestPut { id, + sender, contract, related_contracts, value, @@ -171,7 +172,7 @@ impl Operation for PutOp { } => { // Get the contract key and own location let key = contract.key(); - let sender = op_manager.ring.connection_manager.own_location(); + let own_location = op_manager.ring.connection_manager.own_location(); tracing::info!( "Requesting put for contract {} from {} to {}", @@ -268,34 +269,41 @@ impl Operation for PutOp { // Create a SeekNode message to forward to the next hop return_msg = Some(PutMsg::SeekNode { id: *id, - sender, + sender: sender.clone(), target: forward_target, value: modified_value.clone(), contract: contract.clone(), related_contracts: related_contracts.clone(), htl: *htl, }); + + // Transition to AwaitingResponse state to handle future SuccessfulPut messages + new_state = Some(PutState::AwaitingResponse { + key, + upstream: Some(sender.clone()), + contract: contract.clone(), + state: modified_value, + subscribe: false, + }); } else { // No other peers to forward to - we're the final destination tracing::debug!( tx = %id, %key, - "No peers to forward to - handling PUT completion locally" + "No peers to forward to - handling PUT completion locally, sending SuccessfulPut back to sender" ); - return_msg = None; - } - // Transition to AwaitingResponse state to handle future SuccessfulPut messages - new_state = Some(PutState::AwaitingResponse { - key, - upstream: match &self.state { - Some(PutState::ReceivedRequest) => None, - _ => None, - }, - contract: contract.clone(), - state: modified_value, - subscribe: false, - }); + // Send SuccessfulPut back to the sender (upstream node) + return_msg = Some(PutMsg::SuccessfulPut { + id: *id, + target: sender.clone(), + key, + sender: own_location, + }); + + // Mark operation as finished + new_state = Some(PutState::Finished { key }); + } } PutMsg::SeekNode { id, @@ -1110,6 +1118,7 @@ pub(crate) async fn request_put(op_manager: &OpManager, mut put_op: PutOp) -> Re // Create RequestPut message and forward to target peer let msg = PutMsg::RequestPut { id, + sender: own_location, contract, related_contracts, value, @@ -1272,6 +1281,7 @@ mod messages { /// Internal node instruction to find a route to the target node. 
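+        /// Carries the requesting peer's location (`sender`) so that the node which ends up
+        /// handling the put can reply with `SuccessfulPut` directly upstream.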
RequestPut { id: Transaction, + sender: PeerKeyLocation, contract: ContractContainer, #[serde(deserialize_with = "RelatedContracts::deser_related_contracts")] related_contracts: RelatedContracts<'static>, diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index 6af9af7af..e67be5c16 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -4,7 +4,7 @@ use freenet_stdlib::prelude::*; pub(crate) use self::messages::UpdateMsg; use super::{OpEnum, OpError, OpInitialization, OpOutcome, Operation, OperationResult}; -use crate::contract::ContractHandlerEvent; +use crate::contract::{ContractHandlerEvent, StoreResponse}; use crate::message::{InnerMessage, NetMessage, Transaction}; use crate::node::IsOperationCompleted; use crate::ring::{Location, PeerKeyLocation, RingError}; @@ -137,30 +137,170 @@ impl Operation for UpdateOp { UpdateMsg::RequestUpdate { id, key, + sender: request_sender, target, related_contracts, value, } => { - let sender = op_manager.ring.connection_manager.own_location(); + let self_location = op_manager.ring.connection_manager.own_location(); tracing::debug!( - "UPDATE RequestUpdate: forwarding update for contract {} from {} to {}", - key, - sender.peer, + tx = %id, + %key, + executing_peer = %self_location.peer, + request_sender = %request_sender.peer, + target_peer = %target.peer, + "UPDATE RequestUpdate: executing_peer={} received request from={} targeting={}", + self_location.peer, + request_sender.peer, target.peer ); - return_msg = Some(UpdateMsg::SeekNode { - id: *id, - sender, - target: target.clone(), - value: value.clone(), - key: *key, - related_contracts: related_contracts.clone(), - }); + // If target is not us, this message is meant for another peer + // This can happen when the initiator processes its own message before sending to network + if target.peer != self_location.peer { + tracing::debug!( + tx = %id, + %key, + our_peer = %self_location.peer, + target_peer = %target.peer, + "RequestUpdate target is not us - message will be routed to target" + ); + // Keep current state, message will be sent to target peer via network + return_msg = None; + new_state = self.state; + } else { + // Target is us - process the request + // Determine if this is a local request (from our own client) or remote request + let is_local_request = + matches!(&self.state, Some(UpdateState::PrepareRequest { .. })) + || matches!( + &self.state, + Some(UpdateState::AwaitingResponse { upstream: None, .. }) + ); + let upstream = if is_local_request { + None // No upstream - we are the initiator + } else { + Some(request_sender.clone()) // Upstream is the peer that sent us this request + }; - // no changes to state yet, still in AwaitResponse state - new_state = self.state; + // First check if we have the contract locally + let has_contract = match op_manager + .notify_contract_handler(ContractHandlerEvent::GetQuery { + key: *key, + return_contract_code: false, + }) + .await + { + Ok(ContractHandlerEvent::GetResponse { + response: Ok(StoreResponse { state: Some(_), .. }), + .. 
+ }) => { + tracing::debug!(tx = %id, %key, "Contract exists locally, handling UPDATE"); + true + } + _ => { + tracing::debug!(tx = %id, %key, "Contract not found locally"); + false + } + }; + + if has_contract { + // We have the contract - handle UPDATE locally + tracing::debug!( + tx = %id, + %key, + "Handling UPDATE locally - contract exists" + ); + + // Update contract locally + let updated_value = update_contract( + op_manager, + *key, + value.clone(), + related_contracts.clone(), + ) + .await?; + + // Get broadcast targets for propagating UPDATE to subscribers + let broadcast_to = + op_manager.get_broadcast_targets_update(key, &request_sender.peer); + + // Create success message to send back + let raw_state = State::from(updated_value); + let summary = StateSummary::from(raw_state.into_bytes()); + + if broadcast_to.is_empty() { + // No peers to broadcast to - just send success response + return_msg = Some(UpdateMsg::SuccessfulUpdate { + id: *id, + target: request_sender.clone(), + summary: summary.clone(), + key: *key, + sender: self_location.clone(), + }); + new_state = Some(UpdateState::Finished { key: *key, summary }); + } else { + // Broadcast to other peers + match try_to_broadcast( + *id, + true, // last_hop - we're handling locally + op_manager, + self.state, + (broadcast_to, request_sender.clone()), + *key, + value.clone(), + false, + ) + .await + { + Ok((state, msg)) => { + new_state = state; + return_msg = msg; + } + Err(err) => return Err(err), + } + } + } else { + // Contract not found locally - forward to another peer + let next_target = op_manager.ring.closest_potentially_caching( + key, + [&self_location.peer, &request_sender.peer].as_slice(), + ); + + if let Some(forward_target) = next_target { + tracing::debug!( + tx = %id, + %key, + next_peer = %forward_target.peer, + "Forwarding UPDATE to peer that might have contract" + ); + + // Create a SeekNode message to forward to the next hop + return_msg = Some(UpdateMsg::SeekNode { + id: *id, + sender: self_location.clone(), + target: forward_target, + value: value.clone(), + key: *key, + related_contracts: related_contracts.clone(), + }); + // Transition to AwaitingResponse state to wait for SuccessfulUpdate + new_state = Some(UpdateState::AwaitingResponse { + key: *key, + upstream, + }); + } else { + // No peers available and we don't have the contract - error + tracing::error!( + tx = %id, + %key, + "Cannot handle UPDATE: contract not found locally and no peers to forward to" + ); + return Err(OpError::RingError(RingError::NoCachingPeers(*key))); + } + } + } } UpdateMsg::SeekNode { id, @@ -170,67 +310,126 @@ impl Operation for UpdateOp { target, sender, } => { - // Check if we're seeding or subscribed to this contract - let is_seeding = op_manager.ring.is_seeding_contract(key); - let has_subscribers = op_manager.ring.subscribers_of(key).is_some(); - let should_handle_update = is_seeding || has_subscribers; - - tracing::info!( - "UPDATE_RECEIVED: tx={} contract={:.8} from={} at={} seeding={} subscribers={}", - id, - key, - sender.peer, - target.peer, - is_seeding, - has_subscribers - ); - - tracing::debug!( - tx = %id, - %key, - target = %target.peer, - sender = %sender.peer, - is_seeding = %is_seeding, - has_subscribers = %has_subscribers, - "Updating contract at target peer", - ); - - let broadcast_to = op_manager.get_broadcast_targets_update(key, &sender.peer); + // Check if we have the contract locally + let has_contract = match op_manager + .notify_contract_handler(ContractHandlerEvent::GetQuery { + key: *key, + 
return_contract_code: false, + }) + .await + { + Ok(ContractHandlerEvent::GetResponse { + response: Ok(StoreResponse { state: Some(_), .. }), + .. + }) => { + tracing::debug!(tx = %id, %key, "Contract exists locally for SeekNode UPDATE"); + true + } + _ => { + tracing::debug!(tx = %id, %key, "Contract not found locally for SeekNode UPDATE"); + false + } + }; - if should_handle_update { - tracing::debug!( - "Peer is seeding or has subscribers for contract. About to update it" - ); - update_contract(op_manager, *key, value.clone(), related_contracts.clone()) - .await?; + if has_contract { + tracing::debug!("Contract found locally - handling UPDATE"); + let updated_value = update_contract( + op_manager, + *key, + value.clone(), + related_contracts.clone(), + ) + .await?; tracing::debug!( tx = %id, "Successfully updated a value for contract {} @ {:?} - update", key, target.location ); + + // Get broadcast targets + let broadcast_to = + op_manager.get_broadcast_targets_update(key, &sender.peer); + + // If no peers to broadcast to, send success response directly + if broadcast_to.is_empty() { + tracing::debug!( + tx = %id, + %key, + "No broadcast targets for SeekNode - completing with SuccessfulUpdate" + ); + + // Create success message to send back to sender + let raw_state = State::from(updated_value); + let summary = StateSummary::from(raw_state.into_bytes()); + + return_msg = Some(UpdateMsg::SuccessfulUpdate { + id: *id, + target: sender.clone(), + summary, + key: *key, + sender: op_manager.ring.connection_manager.own_location(), + }); + new_state = None; + } else { + // Have peers to broadcast to - use try_to_broadcast + match try_to_broadcast( + *id, + true, + op_manager, + self.state, + (broadcast_to, sender.clone()), + *key, + value.clone(), + false, + ) + .await + { + Ok((state, msg)) => { + new_state = state; + return_msg = msg; + } + Err(err) => return Err(err), + } + } } else { - tracing::debug!("contract not found in this peer (not seeding and no subscribers). 
Should throw an error"); - return Err(OpError::RingError(RingError::NoCachingPeers(*key))); - } + // Contract not found - forward to another peer + let self_location = op_manager.ring.connection_manager.own_location(); + let next_target = op_manager.ring.closest_potentially_caching( + key, + [&sender.peer, &self_location.peer].as_slice(), + ); - match try_to_broadcast( - *id, - true, - op_manager, - self.state, - (broadcast_to, sender.clone()), - *key, - value.clone(), - false, - ) - .await - { - Ok((state, msg)) => { - new_state = state; - return_msg = msg; + if let Some(forward_target) = next_target { + tracing::debug!( + tx = %id, + %key, + next_peer = %forward_target.peer, + "Contract not found - forwarding SeekNode to next peer" + ); + + // Forward SeekNode to the next peer + return_msg = Some(UpdateMsg::SeekNode { + id: *id, + sender: self_location.clone(), + target: forward_target, + value: value.clone(), + key: *key, + related_contracts: related_contracts.clone(), + }); + new_state = Some(UpdateState::AwaitingResponse { + key: *key, + upstream: Some(sender.clone()), + }); + } else { + // No more peers to try - error + tracing::error!( + tx = %id, + %key, + "Cannot handle UPDATE SeekNode: contract not found and no peers to forward to" + ); + return Err(OpError::RingError(RingError::NoCachingPeers(*key))); } - Err(err) => return Err(err), } } UpdateMsg::BroadcastTo { @@ -671,11 +870,20 @@ pub(crate) async fn request_update( // the initial request must provide: // - a peer as close as possible to the contract location // - and the value to update - let target = if let Some(location) = op_manager.ring.subscribers_of(&key) { - location - .clone() - .pop() - .ok_or(OpError::RingError(RingError::NoLocation))? + let target_from_subscribers = if let Some(subscribers) = op_manager.ring.subscribers_of(&key) { + // Clone and filter out self from subscribers to prevent self-targeting + let mut filtered_subscribers: Vec<_> = subscribers + .iter() + .filter(|sub| sub.peer != sender.peer) + .cloned() + .collect(); + filtered_subscribers.pop() + } else { + None + }; + + let target = if let Some(remote_subscriber) = target_from_subscribers { + remote_subscriber } else { // Find the best peer to send the update to let remote_target = op_manager @@ -686,7 +894,7 @@ pub(crate) async fn request_update( // Subscribe to the contract op_manager .ring - .add_subscriber(&key, sender) + .add_subscriber(&key, sender.clone()) .map_err(|_| RingError::NoCachingPeers(key))?; target @@ -815,6 +1023,7 @@ pub(crate) async fn request_update( let msg = UpdateMsg::RequestUpdate { id, key, + sender, related_contracts, target, value, @@ -855,6 +1064,7 @@ mod messages { RequestUpdate { id: Transaction, key: ContractKey, + sender: PeerKeyLocation, target: PeerKeyLocation, #[serde(deserialize_with = "RelatedContracts::deser_related_contracts")] related_contracts: RelatedContracts<'static>, @@ -938,7 +1148,9 @@ mod messages { impl UpdateMsg { pub fn sender(&self) -> Option<&PeerKeyLocation> { match self { + Self::RequestUpdate { sender, .. } => Some(sender), Self::SeekNode { sender, .. } => Some(sender), + Self::SuccessfulUpdate { sender, .. } => Some(sender), Self::BroadcastTo { sender, .. 
} => Some(sender), _ => None, } diff --git a/crates/core/src/ring/connection.rs b/crates/core/src/ring/connection.rs index 27b406ca1..7b017b7d8 100644 --- a/crates/core/src/ring/connection.rs +++ b/crates/core/src/ring/connection.rs @@ -7,23 +7,8 @@ pub struct Connection { pub(crate) open_at: Instant, } -#[cfg(test)] -use super::Location; -#[cfg(test)] -use crate::node::PeerId; - #[cfg(test)] impl Connection { - pub fn new(peer: PeerId, location: Location) -> Self { - Connection { - location: PeerKeyLocation { - peer, - location: Some(location), - }, - open_at: Instant::now(), - } - } - pub fn get_location(&self) -> &PeerKeyLocation { &self.location }