Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/proto/h1/conn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ use std::marker::PhantomData;
use bytes::{Buf, Bytes};
use futures::{Async, Poll};
use http::{HeaderMap, Method, Version};
use http::header::{HeaderValue, CONNECTION};
use tokio_io::{AsyncRead, AsyncWrite};

use ::Chunk;
use proto::{BodyLength, DecodedLength, MessageHead};
use headers::connection_keep_alive;
use super::io::{Buffered};
use super::{EncodedBuf, Encode, Encoder, /*Decode,*/ Decoder, Http1Transaction, ParseContext};

Expand Down Expand Up @@ -438,12 +440,38 @@ where I: AsyncRead + AsyncWrite,
}
}

// Fix keep-alives when Connection: keep-alive header is not present
fn fix_keep_alive(&mut self, head: &mut MessageHead<T::Outgoing>) {
let outgoing_is_keep_alive = head
.headers
.get(CONNECTION)
.and_then(|value| Some(connection_keep_alive(value)))
.unwrap_or(false);

if !outgoing_is_keep_alive {
match head.version {
// If response is version 1.0 and keep-alive is not present in the response,
// disable keep-alive so the server closes the connection
Version::HTTP_10 => self.state.disable_keep_alive(),
// If response is version 1.1 and keep-alive is wanted, add
// Connection: keep-alive header when not present
Version::HTTP_11 => if self.state.wants_keep_alive() {
head.headers
.insert(CONNECTION, HeaderValue::from_static("keep-alive"));
},
_ => (),
}
}
}

// If we know the remote speaks an older version, we try to fix up any messages
// to work with our older peer.
fn enforce_version(&mut self, head: &mut MessageHead<T::Outgoing>) {

match self.state.version {
Version::HTTP_10 => {
// Fixes response or connection when keep-alive header is not present
self.fix_keep_alive(head);
// If the remote only knows HTTP/1.0, we should force ourselves
// to do only speak HTTP/1.0 as well.
head.version = Version::HTTP_10;
Expand Down
70 changes: 69 additions & 1 deletion tests/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use tokio::reactor::Handle;
use tokio_io::{AsyncRead, AsyncWrite};


use hyper::{Body, Request, Response, StatusCode};
use hyper::{Body, Request, Response, StatusCode, Version};
use hyper::client::Client;
use hyper::server::conn::Http;
use hyper::server::Server;
Expand Down Expand Up @@ -637,6 +637,7 @@ fn keep_alive() {
fn http_10_keep_alive() {
let foo_bar = b"foo bar baz";
let server = serve();
// Response version 1.1 with no keep-alive header will downgrade to 1.0 when served
server.reply()
.header("content-length", foo_bar.len().to_string())
.body(foo_bar);
Expand All @@ -658,6 +659,10 @@ fn http_10_keep_alive() {
}
}

// Connection: keep-alive header should be added when downgrading to a 1.0 response
let response = String::from_utf8(buf.to_vec()).unwrap();
response.contains("Connection: keep-alive\r\n");

// try again!

let quux = b"zar quux";
Expand All @@ -682,6 +687,69 @@ fn http_10_keep_alive() {
}
}

#[test]
fn http_10_close_on_no_ka() {
let foo_bar = b"foo bar baz";
let server = serve();

// A server response with version 1.0 but no keep-alive header
server
.reply()
.version(Version::HTTP_10)
.header("content-length", foo_bar.len().to_string())
.body(foo_bar);
let mut req = connect(server.addr());

// The client request with version 1.0 that may have the keep-alive header
req.write_all(
b"\
GET / HTTP/1.0\r\n\
Host: example.domain\r\n\
Connection: keep-alive\r\n\
\r\n\
",
).expect("writing 1");

let mut buf = [0; 1024 * 8];
loop {
let n = req.read(&mut buf[..]).expect("reading 1");
if n < buf.len() {
if &buf[n - foo_bar.len()..n] == foo_bar {
break;
} else {
}
}
}

// try again!

let quux = b"zar quux";
server
.reply()
.header("content-length", quux.len().to_string())
.body(quux);

// the write can possibly succeed, since it fills the kernel buffer on the first write
let _ = req.write_all(
b"\
GET /quux HTTP/1.1\r\n\
Host: example.domain\r\n\
Connection: close\r\n\
\r\n\
",
);

let mut buf = [0; 1024 * 8];
match req.read(&mut buf[..]) {
// Ok(0) means EOF, so a proper shutdown
// Err(_) could mean ConnReset or something, also fine
Ok(0) | Err(_) => {}
Ok(n) => {
panic!("read {} bytes on a disabled keep-alive socket", n);
}
}
}

#[test]
fn disable_keep_alive() {
let foo_bar = b"foo bar baz";
Expand Down