Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/src/upgrade/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ type NameWrapIter<I> =
std::iter::Map<I, fn(<I as Iterator>::Item) -> NameWrap<<I as Iterator>::Item>>;

/// Wrapper type to expose an `AsRef<[u8]>` impl for all types implementing `ProtocolName`.
#[derive(Clone)]
struct NameWrap<N>(N);

impl<N: ProtocolName> AsRef<[u8]> for NameWrap<N> {
Expand Down
2 changes: 1 addition & 1 deletion core/src/upgrade/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl<T: AsRef<[u8]>> ProtocolName for T {
/// or both.
pub trait UpgradeInfo {
/// Opaque type representing a negotiable protocol.
type Info: ProtocolName;
type Info: ProtocolName + Clone;
/// Iterator returned by `protocol_info`.
type InfoIter: IntoIterator<Item = Self::Info>;

Expand Down
292 changes: 186 additions & 106 deletions misc/multistream-select/src/dialer_select.rs

Large diffs are not rendered by default.

161 changes: 65 additions & 96 deletions misc/multistream-select/src/length_delimited.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.

use bytes::Bytes;
use futures::{Async, Poll, Sink, StartSend, Stream};
use smallvec::SmallVec;
use std::{io::{Error as IoError, ErrorKind as IoErrorKind}, marker::PhantomData, u16};
use tokio_codec::FramedWrite;
use std::{io, u16};
use tokio_codec::{Encoder, FramedWrite};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::codec::UviBytes;
use unsigned_varint::decode;

/// `Stream` and `Sink` wrapping some `AsyncRead + AsyncWrite` object to read
/// and write unsigned-varint prefixed frames.
///
/// We purposely only support a frame length of under 64kiB. Frames mostly consist
/// in a short protocol name, which is highly unlikely to be more than 64kiB long.
pub struct LengthDelimited<I, S> {
pub struct LengthDelimited<R, C> {
// The inner socket where data is pulled from.
inner: FramedWrite<S, UviBytes>,
inner: FramedWrite<R, C>,
// Intermediary buffer where we put either the length of the next frame of data, or the frame
// of data itself before it is returned.
// Must always contain enough space to read data from `inner`.
internal_buffer: SmallVec<[u8; 64]>,
// Number of bytes within `internal_buffer` that contain valid data.
internal_buffer_pos: usize,
// State of the decoder.
state: State,
marker: PhantomData<I>,
state: State
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
Expand All @@ -52,24 +52,21 @@ enum State {
ReadingData { frame_len: u16 },
}

impl<I, S> LengthDelimited<I, S>
impl<R, C> LengthDelimited<R, C>
where
S: AsyncWrite
R: AsyncWrite,
C: Encoder
{
pub fn new(inner: S) -> LengthDelimited<I, S> {
let mut encoder = UviBytes::default();
encoder.set_max_len(usize::from(u16::MAX));

pub fn new(inner: R, codec: C) -> LengthDelimited<R, C> {
LengthDelimited {
inner: FramedWrite::new(inner, encoder),
inner: FramedWrite::new(inner, codec),
internal_buffer: {
let mut v = SmallVec::new();
v.push(0);
v
},
internal_buffer_pos: 0,
state: State::ReadingLength,
marker: PhantomData,
state: State::ReadingLength
}
}

Expand All @@ -85,20 +82,19 @@ where
/// the modifiers provided by the `futures` crate) will always leave the object in a state in
/// which `into_inner()` will not panic.
#[inline]
pub fn into_inner(self) -> S {
pub fn into_inner(self) -> R {
assert_eq!(self.state, State::ReadingLength);
assert_eq!(self.internal_buffer_pos, 0);
self.inner.into_inner()
}
}

impl<I, S> Stream for LengthDelimited<I, S>
impl<R, C> Stream for LengthDelimited<R, C>
where
S: AsyncRead,
I: for<'r> From<&'r [u8]>,
R: AsyncRead
{
type Item = I;
type Error = IoError;
type Item = Bytes;
type Error = io::Error;

fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
Expand All @@ -107,23 +103,21 @@ where

match self.state {
State::ReadingLength => {
match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
Ok(0) => {
// EOF
if self.internal_buffer_pos == 0 {
return Ok(Async::Ready(None));
} else {
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
return Err(io::ErrorKind::UnexpectedEof.into());
}
}
Ok(n) => {
debug_assert_eq!(n, 1);
self.internal_buffer_pos += n;
}
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady);
}
Err(err) => {
Expand All @@ -136,7 +130,10 @@ where
if (*self.internal_buffer.last().unwrap_or(&0) & 0x80) == 0 {
// End of length prefix. Most of the time we will switch to reading data,
// but we need to handle a few corner cases first.
let frame_len = decode_length_prefix(&self.internal_buffer);
let (frame_len, _) = decode::u16(&self.internal_buffer).map_err(|e| {
log::debug!("invalid length prefix: {}", e);
io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
})?;

if frame_len >= 1 {
self.state = State::ReadingData { frame_len };
Expand All @@ -154,33 +151,22 @@ where
}
} else if self.internal_buffer_pos >= 2 {
// Length prefix is too long. See module doc for info about max frame len.
return Err(IoError::new(
IoErrorKind::InvalidData,
"frame length too long",
));
return Err(io::Error::new(io::ErrorKind::InvalidData, "frame length too long"));
} else {
// Prepare for next read.
self.internal_buffer.push(0);
}
}

State::ReadingData { frame_len } => {
match self.inner
.get_mut()
.read(&mut self.internal_buffer[self.internal_buffer_pos..])
{
Ok(0) => {
return Err(IoError::new(IoErrorKind::BrokenPipe, "unexpected eof"));
}
let slice = &mut self.internal_buffer[self.internal_buffer_pos..];
match self.inner.get_mut().read(slice) {
Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()),
Ok(n) => self.internal_buffer_pos += n,
Err(ref err) if err.kind() == IoErrorKind::WouldBlock => {
return Ok(Async::NotReady);
}
Err(err) => {
return Err(err);
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Ok(Async::NotReady)
}
Err(err) => return Err(err)
};

if self.internal_buffer_pos >= frame_len as usize {
// Finished reading the frame of data.
self.state = State::ReadingLength;
Expand All @@ -196,12 +182,13 @@ where
}
}

impl<I, S> Sink for LengthDelimited<I, S>
impl<R, C> Sink for LengthDelimited<R, C>
where
S: AsyncWrite
R: AsyncWrite,
C: Encoder
{
type SinkItem = <FramedWrite<S, UviBytes> as Sink>::SinkItem;
type SinkError = <FramedWrite<S, UviBytes> as Sink>::SinkError;
type SinkItem = <FramedWrite<R, C> as Sink>::SinkItem;
type SinkError = <FramedWrite<R, C> as Sink>::SinkError;

#[inline]
fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
Expand All @@ -219,42 +206,25 @@ where
}
}

fn decode_length_prefix(buf: &[u8]) -> u16 {
debug_assert!(buf.len() <= 2);

let mut sum = 0u16;

for &byte in buf.iter().rev() {
let byte = byte & 0x7f;
sum <<= 7;
debug_assert!(sum.checked_add(u16::from(byte)).is_some());
sum += u16::from(byte);
}

sum
}

#[cfg(test)]
mod tests {
use futures::{Future, Stream};
use crate::length_delimited::LengthDelimited;
use std::io::Cursor;
use std::io::ErrorKind;
use std::io::{Cursor, ErrorKind};
use unsigned_varint::codec::UviBytes;

#[test]
fn basic_read() {
let data = vec![6, 9, 8, 7, 6, 5, 4];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4]]);
}

#[test]
fn basic_read_two() {
let data = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(recved, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
}
Expand All @@ -266,8 +236,7 @@ mod tests {
let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
let mut data = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
data.extend(frame.clone().into_iter());
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed
.into_future()
.map(|(m, _)| m)
Expand All @@ -281,24 +250,24 @@ mod tests {
fn packet_len_too_long() {
let mut data = vec![0x81, 0x81, 0x1];
data.extend((0..16513).map(|_| 0));
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed
.into_future()
.map(|(m, _)| m)
.map_err(|(err, _)| err)
.wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::InvalidData),
_ => panic!(),

if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::InvalidData)
} else {
panic!()
}
}

#[test]
fn empty_frames() {
let data = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait().unwrap();
assert_eq!(
recved,
Expand All @@ -315,36 +284,36 @@ mod tests {
#[test]
fn unexpected_eof_in_len() {
let data = vec![0x89];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}

#[test]
fn unexpected_eof_in_data() {
let data = vec![5];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}

#[test]
fn unexpected_eof_in_data2() {
let data = vec![5, 9, 8, 7];
let framed = LengthDelimited::<Vec<u8>, _>::new(Cursor::new(data));

let framed = LengthDelimited::new(Cursor::new(data), UviBytes::<Vec<_>>::default());
let recved = framed.collect().wait();
match recved {
Err(io_err) => assert_eq!(io_err.kind(), ErrorKind::BrokenPipe),
_ => panic!(),
if let Err(io_err) = recved {
assert_eq!(io_err.kind(), ErrorKind::UnexpectedEof)
} else {
panic!()
}
}
}
Expand Down
Loading