Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied
<axum::http::request::Parts as FromRequestParts<S>>
<Uri as FromRequestParts<S>>
<Version as FromRequestParts<S>>
<ConnectInfo<T> as FromRequestParts<S>>
<Extensions as FromRequestParts<S>>
<ConnectInfo<T> as FromRequestParts<S>>
and $N others
= note: required for `bool` to implement `FromRequest<(), axum_core::extract::private::ViaParts>`
note: required by a bound in `__axum_macros_check_handler_0_from_request_check`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@ error[E0277]: the trait bound `String: FromRequestParts<S>` is not satisfied
<axum::http::request::Parts as FromRequestParts<S>>
<Uri as FromRequestParts<S>>
<Version as FromRequestParts<S>>
<ConnectInfo<T> as FromRequestParts<S>>
<Extensions as FromRequestParts<S>>
and $N others
6 changes: 3 additions & 3 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ matched-path = []
multipart = ["dep:multer"]
original-uri = []
query = ["dep:serde_urlencoded"]
tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make"]
tokio = ["dep:hyper-util", "dep:tokio", "tokio/net", "tokio/rt", "tower/make", "tokio/macros"]
tower-log = ["tower/log"]
tracing = ["dep:tracing", "axum-core/tracing"]
ws = ["dep:hyper", "tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"]
Expand Down Expand Up @@ -53,8 +53,8 @@ tower-service = "0.3"
# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.4.0", optional = true }
base64 = { version = "0.21.0", optional = true }
hyper = { version = "1.0.0", optional = true }
hyper-util = { version = "0.1.1", features = ["tokio", "server", "server-auto"], optional = true }
hyper = { version = "1.1.0", optional = true }
hyper-util = { version = "0.1.2", features = ["tokio", "server", "server-auto"], optional = true }
multer = { version = "3.0.0", optional = true }
serde_json = { version = "1.0", features = ["raw_value"], optional = true }
serde_path_to_error = { version = "0.1.8", optional = true }
Expand Down
12 changes: 12 additions & 0 deletions axum/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,25 @@ macro_rules! all_the_tuples {
};
}

#[cfg(feature = "tracing")]
macro_rules! trace {
($($tt:tt)*) => {
tracing::trace!($($tt)*)
}
}

#[cfg(feature = "tracing")]
macro_rules! error {
($($tt:tt)*) => {
tracing::error!($($tt)*)
};
}

#[cfg(not(feature = "tracing"))]
macro_rules! trace {
($($tt:tt)*) => {};
}

#[cfg(not(feature = "tracing"))]
macro_rules! error {
($($tt:tt)*) => {};
Expand Down
229 changes: 199 additions & 30 deletions axum/src/serve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,29 @@

use std::{
convert::Infallible,
future::{Future, IntoFuture},
fmt::Debug,
future::{poll_fn, Future, IntoFuture},
io,
marker::PhantomData,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};

use axum_core::{body::Body, extract::Request, response::Response};
use futures_util::future::poll_fn;
use futures_util::{pin_mut, FutureExt};
use hyper::body::Incoming;
use hyper_util::{
rt::{TokioExecutor, TokioIo},
server::conn::auto::Builder,
};
use pin_project_lite::pin_project;
use tokio::net::{TcpListener, TcpStream};
use tokio::{
net::{TcpListener, TcpStream},
sync::watch,
};
use tower::util::{Oneshot, ServiceExt};
use tower_service::Service;

Expand Down Expand Up @@ -110,9 +115,25 @@ pub struct Serve<M, S> {
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> std::fmt::Debug for Serve<M, S>
impl<M, S> Serve<M, S> {
/// TODO
pub fn with_graceful_shutdown<F>(self, signal: F) -> WithGracefulShutdown<M, S, F>
where
F: Future<Output = ()> + Send + 'static,
{
WithGracefulShutdown {
tcp_listener: self.tcp_listener,
make_service: self.make_service,
signal,
_marker: PhantomData,
}
}
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S> Debug for Serve<M, S>
where
M: std::fmt::Debug,
M: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
Expand Down Expand Up @@ -148,30 +169,9 @@ where
} = self;

loop {
let (tcp_stream, remote_addr) = match tcp_listener.accept().await {
Ok(conn) => conn,
Err(e) => {
// Connection errors can be ignored directly, continue
// by accepting the next request.
if is_connection_error(&e) {
continue;
}

// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
//
// > A possible scenario is that the process has hit the max open files
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
//
// hyper allowed customizing this but axum does not.
error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
continue;
}
let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await {
Some(conn) => conn,
None => continue,
};
let tcp_stream = TokioIo::new(tcp_stream);

Expand All @@ -191,7 +191,7 @@ where
service: tower_service,
};

tokio::task::spawn(async move {
tokio::spawn(async move {
match Builder::new(TokioExecutor::new())
// upgrades needed for websockets
.serve_connection_with_upgrades(tcp_stream, hyper_service)
Expand All @@ -212,6 +212,149 @@ where
}
}

/// Serve future with graceful shutdown enabled.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
pub struct WithGracefulShutdown<M, S, F> {
tcp_listener: TcpListener,
make_service: M,
signal: F,
_marker: PhantomData<S>,
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> Debug for WithGracefulShutdown<M, S, F>
where
M: Debug,
S: Debug,
F: Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let Self {
tcp_listener,
make_service,
signal,
_marker: _,
} = self;

f.debug_struct("WithGracefulShutdown")
.field("tcp_listener", tcp_listener)
.field("make_service", make_service)
.field("signal", signal)
.finish()
}
}

#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
impl<M, S, F> IntoFuture for WithGracefulShutdown<M, S, F>
where
M: for<'a> Service<IncomingStream<'a>, Error = Infallible, Response = S> + Send + 'static,
for<'a> <M as Service<IncomingStream<'a>>>::Future: Send,
S: Service<Request, Response = Response, Error = Infallible> + Clone + Send + 'static,
S::Future: Send,
F: Future<Output = ()> + Send + 'static,
{
type Output = io::Result<()>;
type IntoFuture = private::ServeFuture;

fn into_future(self) -> Self::IntoFuture {
let Self {
tcp_listener,
mut make_service,
signal,
_marker: _,
} = self;

let (signal_tx, signal_rx) = watch::channel(());
let signal_tx = Arc::new(signal_tx);
tokio::spawn(async move {
signal.await;
trace!("received graceful shutdown signal. Telling tasks to shutdown");
drop(signal_rx);
});

let (close_tx, close_rx) = watch::channel(());

private::ServeFuture(Box::pin(async move {
loop {
let (tcp_stream, remote_addr) = tokio::select! {
conn = tcp_accept(&tcp_listener) => {
match conn {
Some(conn) => conn,
None => continue,
}
}
_ = signal_tx.closed() => {
trace!("signal received, not accepting new connections");
break;
}
};
let tcp_stream = TokioIo::new(tcp_stream);

trace!("connection {remote_addr} accepted");

poll_fn(|cx| make_service.poll_ready(cx))
.await
.unwrap_or_else(|err| match err {});

let tower_service = make_service
.call(IncomingStream {
tcp_stream: &tcp_stream,
remote_addr,
})
.await
.unwrap_or_else(|err| match err {});

let hyper_service = TowerToHyperService {
service: tower_service,
};

let signal_tx = Arc::clone(&signal_tx);

let close_rx = close_rx.clone();

tokio::spawn(async move {
let builder = Builder::new(TokioExecutor::new());
let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service);
pin_mut!(conn);

let signal_closed = signal_tx.closed().fuse();
pin_mut!(signal_closed);

loop {
tokio::select! {
result = conn.as_mut() => {
if let Err(_err) = result {
trace!("failed to serve connection: {_err:#}");
}
break;
}
_ = &mut signal_closed => {
trace!("signal received in task, starting graceful shutdown");
conn.as_mut().graceful_shutdown();
}
}
}

trace!("connection {remote_addr} closed");

drop(close_rx);
});
}

drop(close_rx);
drop(tcp_listener);

trace!(
"waiting for {} task(s) to finish",
close_tx.receiver_count()
);
close_tx.closed().await;

Ok(())
}))
}
}

fn is_connection_error(e: &io::Error) -> bool {
matches!(
e.kind(),
Expand All @@ -221,6 +364,32 @@ fn is_connection_error(e: &io::Error) -> bool {
)
}

async fn tcp_accept(listener: &TcpListener) -> Option<(TcpStream, SocketAddr)> {
match listener.accept().await {
Ok(conn) => Some(conn),
Err(e) => {
if is_connection_error(&e) {
return None;
}

// [From `hyper::Server` in 0.14](https://github.com/hyperium/hyper/blob/v0.14.27/src/server/tcp.rs#L186)
//
// > A possible scenario is that the process has hit the max open files
// > allowed, and so trying to accept a new connection will fail with
// > `EMFILE`. In some cases, it's preferable to just wait for some time, if
// > the application will likely close some files (or connections), and try
// > to accept the connection again. If this option is `true`, the error
// > will be logged at the `error` level, since it is still a big deal,
// > and then the listener will sleep for 1 second.
//
// hyper allowed customizing this but axum does not.
error!("accept error: {e}");
tokio::time::sleep(Duration::from_secs(1)).await;
None
}
}
}

mod private {
use std::{
future::Future,
Expand Down
2 changes: 1 addition & 1 deletion examples/graceful-shutdown/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ edition = "2021"
publish = false

[dependencies]
axum = { path = "../../axum" }
axum = { path = "../../axum", features = ["tracing"] }
hyper = { version = "1.0", features = [] }
hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] }
tokio = { version = "1.0", features = ["full"] }
Expand Down
Loading