diff --git a/tokio/src/io/util/write_buf.rs b/tokio/src/io/util/write_buf.rs index 879dabcb594..2eb0bb02e68 100644 --- a/tokio/src/io/util/write_buf.rs +++ b/tokio/src/io/util/write_buf.rs @@ -3,7 +3,7 @@ use crate::io::AsyncWrite; use bytes::Buf; use pin_project_lite::pin_project; use std::future::Future; -use std::io; +use std::io::{self, IoSlice}; use std::marker::PhantomPinned; use std::pin::Pin; use std::task::{ready, Context, Poll}; @@ -42,13 +42,22 @@ where type Output = io::Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + const MAX_VECTOR_ELEMENTS: usize = 64; + let me = self.project(); if !me.buf.has_remaining() { return Poll::Ready(Ok(0)); } - let n = ready!(Pin::new(me.writer).poll_write(cx, me.buf.chunk()))?; + let n = if me.writer.is_write_vectored() { + let mut slices = [IoSlice::new(&[]); MAX_VECTOR_ELEMENTS]; + let cnt = me.buf.chunks_vectored(&mut slices); + ready!(Pin::new(&mut *me.writer).poll_write_vectored(cx, &slices[..cnt]))? + } else { + ready!(Pin::new(&mut *me.writer).poll_write(cx, me.buf.chunk()))? + }; + me.buf.advance(n); Poll::Ready(Ok(n)) } diff --git a/tokio/tests/io_write_buf.rs b/tokio/tests/io_write_buf.rs index 9ae655b6ccd..8bd09ad62c9 100644 --- a/tokio/tests/io_write_buf.rs +++ b/tokio/tests/io_write_buf.rs @@ -54,3 +54,59 @@ async fn write_all() { assert_eq!(wr.cnt, 1); assert_eq!(buf.position(), 4); } + +#[tokio::test] +async fn write_buf_vectored() { + struct Wr { + buf: BytesMut, + cnt: usize, + } + + impl AsyncWrite for Wr { + fn poll_write( + self: Pin<&mut Self>, + _cx: &mut Context<'_>, + _buf: &[u8], + ) -> Poll> { + panic!("shouldn't be called") + } + + fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + Ok(()).into() + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let mut n = 0; + for buf in bufs { + self.buf.extend_from_slice(buf); + n += buf.len(); + } + self.cnt += 1; + Ok(n).into() + } + + fn is_write_vectored(&self) -> bool { + true + } + } + + let mut wr = Wr { + buf: BytesMut::with_capacity(64), + cnt: 0, + }; + + let mut buf = Cursor::new(&b"hello world"[..]); + + assert_ok!(wr.write_buf(&mut buf).await); + assert_eq!(wr.buf, b"hello world"[..]); + assert_eq!(wr.cnt, 1); + assert_eq!(buf.position(), 11); +}