Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
5a64e7e
webrtc: Implement proper FIN/FIN_ACK handshake and fix flag handling bug
timwu20 Jan 8, 2026
057eac8
webrtc: Unify event variants and encoding to use optional flag parameter
timwu20 Jan 9, 2026
0027ec7
webrtc: Implement proper shutdown with FIN/FIN_ACK handshake and wait…
timwu20 Jan 9, 2026
6de9d33
webrtc: Add timeout for FIN_ACK to prevent indefinite waiting
timwu20 Jan 9, 2026
4c5917d
webrtc: Fix clippy and fmt
timwu20 Jan 12, 2026
4695143
webrtc: Fix FIN_ACK response to be sent outbound to network instead o…
timwu20 Jan 12, 2026
8578487
webrtc: Fix data channel closure by detecting substream drop regardle…
timwu20 Jan 13, 2026
c7f4ea8
webrtc: Use protobuf-generated Flag enum instead of i32
timwu20 Jan 22, 2026
833da9b
Update src/transport/webrtc/substream.rs
timwu20 Jan 22, 2026
35b99f4
Update src/transport/webrtc/substream.rs
timwu20 Jan 22, 2026
42de0ab
webrtc: Use AtomicWaker instead of Mutex<Option<Waker>> for shutdown_…
timwu20 Jan 22, 2026
2ec58e4
webrtc: Use pinned Sleep instead of tokio::spawn for FIN_ACK timeout
timwu20 Jan 22, 2026
a4effe2
webrtc: Optimize WebRtcMessage encode/decode to reduce allocations
timwu20 Jan 22, 2026
73f76ec
webrtc: Wake blocked writers and close both sides on RESET_STREAM
timwu20 Jan 22, 2026
7aa6084
webrtc: Fix race condition in shutdown waker registration
timwu20 Jan 22, 2026
e3a436e
webrtc: Log warning for unknown flag values in message decoding
timwu20 Jan 22, 2026
0a24d6b
webrtc: Fix typo in WebRtcMessage doc comment
timwu20 Jan 22, 2026
1cc104a
webrtc: Address review feedback for FIN/FIN_ACK handshake
timwu20 Jan 23, 2026
2440797
webrtc: Additional review feedback fixes
timwu20 Jan 23, 2026
a36a04e
webrtc: Fix varint length calculation to use ilog2
timwu20 Jan 23, 2026
d537573
webrtc: Add test for FIN with payload and clarify race condition comm…
timwu20 Jan 23, 2026
184971c
webrtc: Remove separate test timeout
timwu20 Jan 23, 2026
c2ead44
webrtc: Fix shutdown timeout test to wait for 5s FIN_ACK timeout
timwu20 Jan 26, 2026
9ecfe8c
webrtc: Send RESET_STREAM on abrupt substream close
timwu20 Jan 15, 2026
002075c
webrtc: Make multistream-select negotiation spec compliant on outboun…
haikoschol Feb 18, 2026
bd50c8c
webrtc: Make multistream-select negotiation spec compliant on inbound…
haikoschol Feb 18, 2026
1c884a3
webrtc: Refactor webrtc_encode_multistream_message to take a single M…
haikoschol Feb 18, 2026
87dadce
Merge branch 'master' into haiko-webrtc-multistream-nego-fix-2
haikoschol Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 127 additions & 71 deletions src/multistream_select/dialer_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ use crate::{
codec::unsigned_varint::UnsignedVarint,
error::{self, Error, ParseError, SubstreamError},
multistream_select::{
drain_trailing_protocols,
protocol::{
webrtc_encode_multistream_message, HeaderLine, Message, MessageIO, Protocol,
ProtocolError, PROTO_MULTISTREAM_1_0,
Expand Down Expand Up @@ -300,6 +299,12 @@ pub enum HandshakeResult {
/// The returned tuple contains the negotiated protocol and response
/// that must be sent to remote peer.
Succeeded(ProtocolName),

/// The proposed protocol was rejected by the remote peer.
///
/// The caller should check if there are remaining fallback protocols to try
/// via [`WebRtcDialerState::propose_next_fallback()`].
Rejected,
}

/// Handshake state.
Expand Down Expand Up @@ -334,12 +339,9 @@ impl WebRtcDialerState {
protocol: ProtocolName,
fallback_names: Vec<ProtocolName>,
) -> crate::Result<(Self, Vec<u8>)> {
let message = webrtc_encode_multistream_message(
std::iter::once(protocol.clone())
.chain(fallback_names.clone())
.filter_map(|protocol| Protocol::try_from(protocol.as_ref()).ok())
.map(Message::Protocol),
)?
let message = webrtc_encode_multistream_message(Message::Protocol(
Protocol::try_from(protocol.as_ref()).map_err(|_| Error::InvalidData)?,
))?
.freeze()
.to_vec();

Expand All @@ -353,72 +355,83 @@ impl WebRtcDialerState {
))
}

/// Propose the next fallback protocol to the remote peer.
///
/// Returns `None` if there are no more fallback protocols to try.
/// Returns `Some(message)` with the encoded message to send, containing the protocol name.
pub fn propose_next_fallback(&mut self) -> crate::Result<Option<Vec<u8>>> {
if self.fallback_names.is_empty() {
return Ok(None);
}

let next = self.fallback_names.remove(0);
self.protocol = next;
self.state = HandshakeState::WaitingResponse;

let message = webrtc_encode_multistream_message(Message::Protocol(
Protocol::try_from(self.protocol.as_ref()).map_err(|_| Error::InvalidData)?,
))?
.freeze()
.to_vec();

Ok(Some(message))
}

/// Register response to [`WebRtcDialerState`].
pub fn register_response(
&mut self,
payload: Vec<u8>,
) -> Result<HandshakeResult, crate::error::NegotiationError> {
// All multistream-select messages are length-prefixed. Since this code path is not using
// multistream_select::protocol::MessageIO, we need to decode and remove the length here.
let remaining: &[u8] = &payload;
let (len, tail) = unsigned_varint::decode::usize(remaining).map_err(|error| {
tracing::debug!(
let bytes = Bytes::from(payload);
let mut remaining = bytes.clone();

while !remaining.is_empty() {
let (len, tail) = unsigned_varint::decode::usize(&remaining).map_err(|error| {
tracing::debug!(
target: LOG_TARGET,
?error,
message = ?payload,
"Failed to decode length-prefix in multistream message");
error::NegotiationError::ParseError(ParseError::InvalidData)
})?;
message = ?remaining,
"Failed to decode length-prefix in multistream message",
);
error::NegotiationError::ParseError(ParseError::InvalidData)
})?;

let len_size = remaining.len() - tail.len();
let bytes = Bytes::from(payload);
let payload = bytes.slice(len_size..len_size + len);
let remaining = bytes.slice(len_size + len..);
let message = Message::decode(payload);

tracing::trace!(
target: LOG_TARGET,
?message,
"Decoded message while registering response",
);

let mut protocols = match message {
Ok(Message::Header(HeaderLine::V1)) => {
vec![PROTO_MULTISTREAM_1_0]
let len_size = remaining.len() - tail.len();

if len > tail.len() {
tracing::debug!(
target: LOG_TARGET,
message = ?tail,
length_prefix = len,
actual_length = tail.len(),
"Truncated multistream message",
);
return Err(error::NegotiationError::ParseError(ParseError::InvalidData));
}
Ok(Message::Protocol(protocol)) => vec![protocol],
Ok(Message::Protocols(protocols)) => protocols,
Ok(Message::NotAvailable) =>
return match &self.state {
HandshakeState::WaitingProtocol => Err(
error::NegotiationError::MultistreamSelectError(NegotiationError::Failed),
),
_ => Err(error::NegotiationError::StateMismatch),
},
Ok(Message::ListProtocols) => return Err(error::NegotiationError::StateMismatch),
Err(_) => return Err(error::NegotiationError::ParseError(ParseError::InvalidData)),
};

match drain_trailing_protocols(remaining) {
Ok(protos) => protocols.extend(protos),
Err(error) => return Err(error),
}
let payload = remaining.slice(len_size..len_size + len);
remaining = remaining.slice(len_size + len..);
let message = Message::decode(payload);

let mut protocol_iter = protocols.into_iter();
loop {
match (&self.state, protocol_iter.next()) {
(HandshakeState::WaitingResponse, None) =>
return Err(crate::error::NegotiationError::StateMismatch),
(HandshakeState::WaitingResponse, Some(protocol)) => {
if protocol == PROTO_MULTISTREAM_1_0 {
self.state = HandshakeState::WaitingProtocol;
} else {
return Err(crate::error::NegotiationError::MultistreamSelectError(
NegotiationError::Failed,
));
}
tracing::trace!(
target: LOG_TARGET,
?message,
"Decoded message while registering response",
);

match (&self.state, message) {
(HandshakeState::WaitingResponse, Ok(Message::Header(HeaderLine::V1))) => {
self.state = HandshakeState::WaitingProtocol;
}
(HandshakeState::WaitingProtocol, Some(protocol)) => {
(HandshakeState::WaitingResponse, Ok(Message::Protocol(_))) => {
return Err(crate::error::NegotiationError::MultistreamSelectError(
NegotiationError::Failed,
));
}
(_, Ok(Message::NotAvailable)) => {
return Ok(HandshakeResult::Rejected);
}
(HandshakeState::WaitingProtocol, Ok(Message::Protocol(protocol))) => {
if protocol == PROTO_MULTISTREAM_1_0 {
return Err(crate::error::NegotiationError::StateMismatch);
}
Expand All @@ -437,11 +450,16 @@ impl WebRtcDialerState {
NegotiationError::Failed,
));
}
(HandshakeState::WaitingProtocol, None) => {
return Ok(HandshakeResult::NotReady);
_ => {
return Err(crate::error::NegotiationError::StateMismatch);
}
}
}

match &self.state {
HandshakeState::WaitingProtocol => Ok(HandshakeResult::NotReady),
HandshakeState::WaitingResponse => Err(crate::error::NegotiationError::StateMismatch),
}
}
}

Expand Down Expand Up @@ -816,6 +834,7 @@ mod tests {
)
.unwrap();

// Initial message should only contain the main protocol, not the fallback.
let mut bytes = BytesMut::with_capacity(32);
bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8);
let _ = Message::Header(HeaderLine::V1).encode(&mut bytes).unwrap();
Expand All @@ -824,15 +843,52 @@ mod tests {
bytes.put_u8((proto1.as_ref().len() + 1) as u8); // + 1 for \n
let _ = Message::Protocol(proto1).encode(&mut bytes).unwrap();

let proto2 = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name");
bytes.put_u8((proto2.as_ref().len() + 1) as u8); // + 1 for \n
let _ = Message::Protocol(proto2).encode(&mut bytes).unwrap();

let expected_message = bytes.freeze().to_vec();

assert_eq!(message, expected_message);
}

#[test]
fn propose_next_fallback() {
let (mut dialer_state, _message) = WebRtcDialerState::propose(
ProtocolName::from("/13371338/proto/1"),
vec![ProtocolName::from("/sup/proto/1")],
)
.unwrap();

// Simulate receiving header-only response, transitioning to WaitingProtocol.
let mut header_bytes = BytesMut::with_capacity(32);
header_bytes.put_u8(MSG_MULTISTREAM_1_0.len() as u8);
let _ = Message::Header(HeaderLine::V1).encode(&mut header_bytes).unwrap();
// Append "na" to simulate rejection.
let na_bytes = b"na\n";
header_bytes.put_u8(na_bytes.len() as u8);
header_bytes.put_slice(na_bytes);

match dialer_state.register_response(header_bytes.freeze().to_vec()) {
Ok(HandshakeResult::Rejected) => {}
event => panic!("expected Rejected, got: {event:?}"),
}

// Now propose the next fallback.
let fallback_message = dialer_state
.propose_next_fallback()
.expect("no error")
.expect("should have a fallback");

let mut expected = BytesMut::with_capacity(32);
expected.put_u8(MSG_MULTISTREAM_1_0.len() as u8);
let _ = Message::Header(HeaderLine::V1).encode(&mut expected).unwrap();
let proto = Protocol::try_from(&b"/sup/proto/1"[..]).expect("valid protocol name");
expected.put_u8((proto.as_ref().len() + 1) as u8);
let _ = Message::Protocol(proto).encode(&mut expected).unwrap();

assert_eq!(fallback_message, expected.freeze().to_vec());

// No more fallbacks.
assert!(dialer_state.propose_next_fallback().unwrap().is_none());
}

#[test]
fn register_response_header_only() {
let mut bytes = BytesMut::with_capacity(32);
Expand Down Expand Up @@ -875,9 +931,9 @@ mod tests {

#[test]
fn negotiate_main_protocol() {
let message = webrtc_encode_multistream_message(vec![Message::Protocol(
let message = webrtc_encode_multistream_message(Message::Protocol(
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
)])
))
.unwrap()
.freeze();

Expand All @@ -897,9 +953,9 @@ mod tests {

#[test]
fn negotiate_fallback_protocol() {
let message = webrtc_encode_multistream_message(vec![Message::Protocol(
let message = webrtc_encode_multistream_message(Message::Protocol(
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
)])
))
.unwrap()
.freeze();

Expand Down
24 changes: 10 additions & 14 deletions src/multistream_select/listener_select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,9 +374,7 @@ pub fn webrtc_listener_negotiate<'a>(
if protocol.as_ref() == supported.as_bytes() {
return Ok(ListenerSelectResult::Accepted {
protocol: supported.clone(),
message: webrtc_encode_multistream_message(std::iter::once(
Message::Protocol(protocol),
))?,
message: webrtc_encode_multistream_message(Message::Protocol(protocol))?,
});
}
}
Expand All @@ -388,7 +386,7 @@ pub fn webrtc_listener_negotiate<'a>(
);

Ok(ListenerSelectResult::Rejected {
message: webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))?,
message: webrtc_encode_multistream_message(Message::NotAvailable)?,
})
}

Expand All @@ -407,10 +405,9 @@ mod tests {
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = webrtc_encode_multistream_message(vec![
Message::Protocol(Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap()),
Message::Protocol(Protocol::try_from(&b"/sup/proto/1"[..]).unwrap()),
])
let message = webrtc_encode_multistream_message(Message::Protocol(
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
))
.unwrap()
.freeze();

Expand Down Expand Up @@ -447,10 +444,10 @@ mod tests {
// `webrtc_listener_negotiate()` should reject this invalid message. The error can either be
// `InvalidData` because the message is malformed or `StateMismatch` because the message is
// not expected at this point in the protocol.
let message = webrtc_encode_multistream_message(std::iter::once(Message::Protocols(vec![
let message = webrtc_encode_multistream_message(Message::Protocols(vec![
Protocol::try_from(&b"/13371338/proto/1"[..]).unwrap(),
Protocol::try_from(&b"/sup/proto/1"[..]).unwrap(),
])))
]))
.unwrap()
.freeze();

Expand Down Expand Up @@ -534,9 +531,9 @@ mod tests {
ProtocolName::from("/13371338/proto/3"),
ProtocolName::from("/13371338/proto/4"),
];
let message = webrtc_encode_multistream_message(vec![Message::Protocol(
let message = webrtc_encode_multistream_message(Message::Protocol(
Protocol::try_from(&b"/13371339/proto/1"[..]).unwrap(),
)])
))
.unwrap()
.freeze();

Expand All @@ -545,8 +542,7 @@ mod tests {
Ok(ListenerSelectResult::Rejected { message }) => {
assert_eq!(
message,
webrtc_encode_multistream_message(std::iter::once(Message::NotAvailable))
.unwrap()
webrtc_encode_multistream_message(Message::NotAvailable).unwrap()
);
}
Ok(ListenerSelectResult::Accepted { protocol, message }) => panic!("message accepted"),
Expand Down
29 changes: 13 additions & 16 deletions src/multistream_select/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,24 +234,21 @@ impl Message {
///
/// This implementation may not be compliant with the multistream-select protocol spec.
/// The only purpose of this was to get the `multistream-select` protocol working with smoldot.
pub fn webrtc_encode_multistream_message(
messages: impl IntoIterator<Item = Message>,
) -> crate::Result<BytesMut> {
pub fn webrtc_encode_multistream_message(message: Message) -> crate::Result<BytesMut> {
// encode `/multistream-select/1.0.0` header
let mut bytes = BytesMut::with_capacity(32);
let message = Message::Header(HeaderLine::V1);
message.encode(&mut bytes).map_err(|_| Litep2pError::InvalidData)?;
let mut header = UnsignedVarint::encode(bytes)?;

// encode each message
for message in messages {
let mut proto_bytes = BytesMut::with_capacity(256);
message.encode(&mut proto_bytes).map_err(|_| Litep2pError::InvalidData)?;
let mut proto_bytes = UnsignedVarint::encode(proto_bytes)?;
header.append(&mut proto_bytes);
}

Ok(BytesMut::from(&header[..]))
Message::Header(HeaderLine::V1)
.encode(&mut bytes)
.map_err(|_| Litep2pError::InvalidData)?;
let mut output = UnsignedVarint::encode(bytes)?;

// encode the message
let mut msg_bytes = BytesMut::with_capacity(256);
message.encode(&mut msg_bytes).map_err(|_| Litep2pError::InvalidData)?;
let mut msg_bytes = UnsignedVarint::encode(msg_bytes)?;
output.append(&mut msg_bytes);

Ok(BytesMut::from(&output[..]))
}

/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s.
Expand Down
Loading