diff --git a/src/substream/mod.rs b/src/substream/mod.rs index 5fc1ec3c7..37ba98c05 100644 --- a/src/substream/mod.rs +++ b/src/substream/mod.rs @@ -369,7 +369,34 @@ impl Substream { io.write_all(&payload) .await - .map_err(|_| Error::SubstreamError(SubstreamError::ConnectionClosed)) + .map_err(|_| Error::SubstreamError(SubstreamError::ConnectionClosed))?; + + // Flush the stream. + io.flush().await.map_err(From::from) + } + + /// Send unsigned varint payload to remote peer. + async fn send_unsigned_varint_payload( + io: &mut T, + bytes: Bytes, + max_size: Option, + ) -> crate::Result<()> { + if let Some(max_size) = max_size { + if bytes.len() > max_size { + return Err(Error::IoError(ErrorKind::PermissionDenied)); + } + } + + // Write the length of the frame. + let mut buffer = unsigned_varint::encode::usize_buffer(); + let encoded_len = unsigned_varint::encode::usize(bytes.len(), &mut buffer).len(); + io.write_all(&buffer[..encoded_len]).await?; + + // Write the frame. + io.write_all(bytes.as_ref()).await?; + + // Flush the stream. + io.flush().await.map_err(From::from) } /// Send framed data to remote peer. @@ -386,7 +413,7 @@ impl Substream { /// # Panics /// /// Panics if no codec is provided. - pub async fn send_framed(&mut self, mut bytes: Bytes) -> crate::Result<()> { + pub async fn send_framed(&mut self, bytes: Bytes) -> crate::Result<()> { tracing::trace!( target: LOG_TARGET, peer = ?self.peer, @@ -403,48 +430,16 @@ impl Substream { ProtocolCodec::Unspecified => panic!("codec is unspecified"), ProtocolCodec::Identity(payload_size) => Self::send_identity_payload(substream, payload_size, bytes).await, - ProtocolCodec::UnsignedVarint(max_size) => { - check_size!(max_size, bytes.len()); - - let mut buffer = [0u8; 10]; - let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); - let mut offset = 0; - - while offset < len.len() { - offset += substream.write(&len[offset..]).await?; - } - - while bytes.has_remaining() { - let nwritten = substream.write(&bytes).await?; - bytes.advance(nwritten); - } - - substream.flush().await.map_err(From::from) - } + ProtocolCodec::UnsignedVarint(max_size) => + Self::send_unsigned_varint_payload(substream, bytes, max_size).await, }, #[cfg(feature = "websocket")] SubstreamType::WebSocket(ref mut substream) => match self.codec { ProtocolCodec::Unspecified => panic!("codec is unspecified"), ProtocolCodec::Identity(payload_size) => Self::send_identity_payload(substream, payload_size, bytes).await, - ProtocolCodec::UnsignedVarint(max_size) => { - check_size!(max_size, bytes.len()); - - let mut buffer = [0u8; 10]; - let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); - let mut offset = 0; - - while offset < len.len() { - offset += substream.write(&len[offset..]).await?; - } - - while bytes.has_remaining() { - let nwritten = substream.write(&bytes).await?; - bytes.advance(nwritten); - } - - substream.flush().await.map_err(From::from) - } + ProtocolCodec::UnsignedVarint(max_size) => + Self::send_unsigned_varint_payload(substream, bytes, max_size).await, }, #[cfg(feature = "quic")] SubstreamType::Quic(ref mut substream) => match self.codec { @@ -454,7 +449,7 @@ impl Substream { ProtocolCodec::UnsignedVarint(max_size) => { check_size!(max_size, bytes.len()); - let mut buffer = [0u8; 10]; + let mut buffer = unsigned_varint::encode::usize_buffer(); let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); let len = BytesMut::from(len); @@ -466,24 +461,8 @@ impl Substream { ProtocolCodec::Unspecified => panic!("codec is unspecified"), ProtocolCodec::Identity(payload_size) => Self::send_identity_payload(substream, payload_size, bytes).await, - ProtocolCodec::UnsignedVarint(max_size) => { - check_size!(max_size, bytes.len()); - - let mut buffer = [0u8; 10]; - let len = unsigned_varint::encode::usize(bytes.len(), &mut buffer); - let mut offset = 0; - - while offset < len.len() { - offset += substream.write(&len[offset..]).await?; - } - - while bytes.has_remaining() { - let nwritten = substream.write(&bytes).await?; - bytes.advance(nwritten); - } - - substream.flush().await.map_err(From::from) - } + ProtocolCodec::UnsignedVarint(max_size) => + Self::send_unsigned_varint_payload(substream, bytes, max_size).await, }, } } @@ -722,7 +701,7 @@ impl Sink for Substream { check_size!(max_size, item.len()); let len = { - let mut buffer = [0u8; 10]; + let mut buffer = unsigned_varint::encode::usize_buffer(); let len = unsigned_varint::encode::usize(item.len(), &mut buffer); BytesMut::from(len) };