Skip to content

Commit 9500fb2

Browse files
Add test for poll-based API
1 parent c1224cf commit 9500fb2

File tree

1 file changed

+236
-0
lines changed

1 file changed

+236
-0
lines changed

tests/poll_api.rs

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
use std::future::Future;
2+
use std::io;
3+
use std::net::{Ipv4Addr, SocketAddrV4, SocketAddr};
4+
use std::pin::Pin;
5+
use std::task::{Context, Poll};
6+
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, FutureExt, StreamExt};
7+
use futures::future::BoxFuture;
8+
use futures::stream::FuturesUnordered;
9+
use quickcheck::{Arbitrary, Gen, QuickCheck};
10+
use tokio::net::{TcpListener, TcpStream};
11+
use tokio::runtime::Runtime;
12+
use tokio_util::compat::TokioAsyncReadCompatExt;
13+
use yamux::{Connection, Mode, WindowUpdateMode};
14+
15+
#[test]
16+
fn prop_config_send_recv_multi() {
17+
let _ = env_logger::try_init();
18+
19+
fn prop(msgs: Vec<Msg>, cfg1: TestConfig, cfg2: TestConfig) {
20+
Runtime::new().unwrap().block_on(async move {
21+
let num_messagses = msgs.len();
22+
23+
let (listener, address) = bind().await.expect("bind");
24+
25+
let server = async {
26+
let socket = listener.accept().await.expect("accept").0.compat();
27+
let connection = Connection::new(socket, cfg1.0, Mode::Server);
28+
29+
EchoServer::new(connection).await
30+
};
31+
32+
let client = async {
33+
let socket = TcpStream::connect(address).await.expect("connect").compat();
34+
let connection = Connection::new(socket, cfg2.0, Mode::Client);
35+
36+
MessageSender::new(connection, msgs).await
37+
};
38+
39+
let (server_processed, client_processed) = futures::future::try_join(server, client).await.unwrap();
40+
41+
assert_eq!(server_processed, num_messagses);
42+
assert_eq!(client_processed, num_messagses);
43+
})
44+
}
45+
46+
QuickCheck::new()
47+
.tests(10)
48+
.quickcheck(prop as fn(_, _, _) -> _)
49+
}
50+
51+
#[derive(Clone, Debug)]
52+
struct Msg(Vec<u8>);
53+
54+
impl Arbitrary for Msg {
55+
fn arbitrary(g: &mut Gen) -> Msg {
56+
let mut msg = Msg(Arbitrary::arbitrary(g));
57+
if msg.0.is_empty() {
58+
msg.0.push(Arbitrary::arbitrary(g));
59+
}
60+
61+
msg
62+
}
63+
64+
fn shrink(&self) -> Box<dyn Iterator<Item=Self>> {
65+
Box::new(self.0.shrink().filter(|v| !v.is_empty()).map(|v| Msg(v)))
66+
}
67+
}
68+
69+
#[derive(Clone, Debug)]
70+
struct TestConfig(yamux::Config);
71+
72+
impl Arbitrary for TestConfig {
73+
fn arbitrary(g: &mut Gen) -> Self {
74+
let mut c = yamux::Config::default();
75+
c.set_window_update_mode(if bool::arbitrary(g) {
76+
WindowUpdateMode::OnRead
77+
} else {
78+
WindowUpdateMode::OnReceive
79+
});
80+
c.set_read_after_close(Arbitrary::arbitrary(g));
81+
c.set_receive_window(256 * 1024 + u32::arbitrary(g) % (768 * 1024));
82+
TestConfig(c)
83+
}
84+
}
85+
86+
async fn bind() -> io::Result<(TcpListener, SocketAddr)> {
87+
let i = Ipv4Addr::new(127, 0, 0, 1);
88+
let s = SocketAddr::V4(SocketAddrV4::new(i, 0));
89+
let l = TcpListener::bind(&s).await?;
90+
let a = l.local_addr()?;
91+
Ok((l, a))
92+
}
93+
94+
struct EchoServer<T> {
95+
connection: Connection<T>,
96+
worker_streams: FuturesUnordered<BoxFuture<'static, yamux::Result<()>>>,
97+
streams_processed: usize
98+
}
99+
100+
impl<T> EchoServer<T> {
101+
fn new(connection: Connection<T>) -> Self {
102+
Self {
103+
connection,
104+
worker_streams: FuturesUnordered::default(),
105+
streams_processed: 0
106+
}
107+
}
108+
}
109+
110+
impl<T> Future for EchoServer<T> where T: AsyncRead + AsyncWrite + Unpin {
111+
type Output = yamux::Result<usize>;
112+
113+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
114+
let this = self.get_mut();
115+
116+
loop {
117+
match this.worker_streams.poll_next_unpin(cx) {
118+
Poll::Ready(Some(Ok(()))) => {
119+
this.streams_processed += 1;
120+
continue;
121+
}
122+
Poll::Ready(Some(Err(e))) => {
123+
eprintln!("A stream failed: {}", e);
124+
continue;
125+
}
126+
Poll::Ready(None) | Poll::Pending => {}
127+
}
128+
129+
match this.connection.poll_next_inbound(cx)? {
130+
Poll::Ready(Some(mut stream)) => {
131+
this.worker_streams.push(async move {
132+
{
133+
let (mut r, mut w) = AsyncReadExt::split(&mut stream);
134+
futures::io::copy(&mut r, &mut w).await?;
135+
}
136+
stream.close().await?;
137+
Ok(())
138+
}.boxed());
139+
continue;
140+
}
141+
Poll::Ready(None) => return Poll::Ready(Ok(this.streams_processed)),
142+
Poll::Pending => {}
143+
}
144+
145+
return Poll::Pending;
146+
}
147+
}
148+
}
149+
150+
struct MessageSender<T> {
151+
connection: Connection<T>,
152+
pending_messages: Vec<Msg>,
153+
worker_streams: FuturesUnordered<BoxFuture<'static, ()>>,
154+
streams_processed: usize
155+
}
156+
157+
impl<T> MessageSender<T> {
158+
fn new(connection: Connection<T>, messages: Vec<Msg>) -> Self {
159+
Self {
160+
connection,
161+
pending_messages: messages,
162+
worker_streams: FuturesUnordered::default(),
163+
streams_processed: 0
164+
}
165+
}
166+
}
167+
168+
impl<T> Future for MessageSender<T> where T: AsyncRead + AsyncWrite + Unpin {
169+
type Output = yamux::Result<usize>;
170+
171+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172+
let this = self.get_mut();
173+
174+
loop {
175+
if this.pending_messages.is_empty() && this.worker_streams.is_empty() {
176+
futures::ready!(this.connection.poll_close(cx)?);
177+
178+
return Poll::Ready(Ok(this.streams_processed));
179+
}
180+
181+
match this.worker_streams.poll_next_unpin(cx) {
182+
Poll::Ready(Some(())) => {
183+
this.streams_processed += 1;
184+
continue;
185+
}
186+
Poll::Ready(None) | Poll::Pending => {}
187+
}
188+
189+
if let Some(Msg(message)) = this.pending_messages.pop() {
190+
match this.connection.poll_new_outbound(cx)? {
191+
Poll::Ready(stream) => {
192+
this.worker_streams.push(async move {
193+
let id = stream.id();
194+
let len = message.len();
195+
196+
let (mut reader, mut writer) = AsyncReadExt::split(stream);
197+
198+
let write_fut = async {
199+
writer.write_all(&message).await.unwrap();
200+
log::debug!("C: {}: sent {} bytes", id, len);
201+
writer.close().await.unwrap();
202+
};
203+
204+
let mut received = Vec::new();
205+
let read_fut = async {
206+
reader.read_to_end(&mut received).await.unwrap();
207+
log::debug!("C: {}: received {} bytes", id, received.len());
208+
};
209+
210+
futures::future::join(write_fut, read_fut).await;
211+
212+
assert_eq!(message, received)
213+
}.boxed());
214+
continue;
215+
}
216+
Poll::Pending => {
217+
this.pending_messages.push(Msg(message));
218+
}
219+
}
220+
}
221+
222+
match this.connection.poll_next_inbound(cx)? {
223+
Poll::Ready(Some(stream)) => {
224+
drop(stream);
225+
panic!("Did not expect remote to open a stream");
226+
}
227+
Poll::Ready(None) => {
228+
panic!("Did not expect remote to close the connection");
229+
},
230+
Poll::Pending => {}
231+
}
232+
233+
return Poll::Pending;
234+
}
235+
}
236+
}

0 commit comments

Comments
 (0)