Skip to content

Commit d039c35

Browse files
authored
Add workaround for inference errors after removing B type param (#1835)
1 parent 161bb60 commit d039c35

5 files changed

Lines changed: 115 additions & 11 deletions

File tree

axum/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1111
- **breaking:** Change `sse::Event::json_data` to use `axum_core::Error` as its error type ([#1762])
1212
- **breaking:** Rename `DefaultOnFailedUpdgrade` to `DefaultOnFailedUpgrade` ([#1664])
1313
- **breaking:** Rename `OnFailedUpdgrade` to `OnFailedUpgrade` ([#1664])
14+
- **added:** Add `Router::as_service` to workaround type inference issues when
15+
calling `ServiceExt` methods on a `Router` ([#1835])
1416

1517
[#1762]: https://github.com/tokio-rs/axum/pull/1762
1618
[#1664]: https://github.com/tokio-rs/axum/pull/1664
19+
[#1835]: https://github.com/tokio-rs/axum/pull/1835
1720

1821
# 0.6.9 (24. February, 2023)
1922

axum/src/routing/mod.rs

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ use std::{
1616
collections::HashMap,
1717
convert::Infallible,
1818
fmt,
19+
marker::PhantomData,
1920
sync::Arc,
2021
task::{Context, Poll},
2122
};
@@ -498,6 +499,66 @@ where
498499
Endpoint::NestedRouter(router) => router.call_with_state(req, state),
499500
}
500501
}
502+
503+
/// Convert the router into a [`Service`] with a fixed request body type, to aid type
504+
/// inference.
505+
///
506+
/// In some cases when calling methods from [`tower::ServiceExt`] on a [`Router`] you might get
507+
/// type inference errors along the lines of
508+
///
509+
/// ```not_rust
510+
/// let response = router.ready().await?.call(request).await?;
511+
/// ^^^^^ cannot infer type for type parameter `B`
512+
/// ```
513+
///
514+
/// This happens because `Router` implements [`Service`] with `impl<B> Service<Request<B>> for Router<()>`.
515+
///
516+
/// For example:
517+
///
518+
/// ```compile_fail
519+
/// use axum::{
520+
/// Router,
521+
/// routing::get,
522+
/// http::Request,
523+
/// body::Body,
524+
/// };
525+
/// use tower::{Service, ServiceExt};
526+
///
527+
/// # async fn async_main() -> Result<(), Box<dyn std::error::Error>> {
528+
/// let mut router = Router::new().route("/", get(|| async {}));
529+
/// let request = Request::new(Body::empty());
530+
/// let response = router.ready().await?.call(request).await?;
531+
/// # Ok(())
532+
/// # }
533+
/// ```
534+
///
535+
/// Calling `Router::as_service` fixes that:
536+
///
537+
/// ```
538+
/// use axum::{
539+
/// Router,
540+
/// routing::get,
541+
/// http::Request,
542+
/// body::Body,
543+
/// };
544+
/// use tower::{Service, ServiceExt};
545+
///
546+
/// # async fn async_main() -> Result<(), Box<dyn std::error::Error>> {
547+
/// let mut router = Router::new().route("/", get(|| async {}));
548+
/// let request = Request::new(Body::empty());
549+
/// let response = router.as_service().ready().await?.call(request).await?;
550+
/// # Ok(())
551+
/// # }
552+
/// ```
553+
///
554+
/// This is mainly used when calling `Router` in tests. It shouldn't be necessary when running
555+
/// the `Router` normally via [`Router::into_make_service`].
556+
pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S> {
557+
RouterAsService {
558+
router: self,
559+
_marker: PhantomData,
560+
}
561+
}
501562
}
502563

503564
impl Router {
@@ -560,6 +621,45 @@ where
560621
}
561622
}
562623

624+
/// A [`Router`] converted into a service with a fixed body type.
625+
///
626+
/// See [`Router::as_service`] for more details.
627+
pub struct RouterAsService<'a, B, S = ()> {
628+
router: &'a mut Router<S>,
629+
_marker: PhantomData<B>,
630+
}
631+
632+
impl<'a, B> Service<Request<B>> for RouterAsService<'a, B, ()>
633+
where
634+
B: HttpBody<Data = bytes::Bytes> + Send + 'static,
635+
B::Error: Into<axum_core::BoxError>,
636+
{
637+
type Response = Response;
638+
type Error = Infallible;
639+
type Future = RouteFuture<Infallible>;
640+
641+
#[inline]
642+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
643+
<Router as Service<Request<B>>>::poll_ready(self.router, cx)
644+
}
645+
646+
#[inline]
647+
fn call(&mut self, req: Request<B>) -> Self::Future {
648+
self.router.call(req)
649+
}
650+
}
651+
652+
impl<'a, B, S> fmt::Debug for RouterAsService<'a, B, S>
653+
where
654+
S: fmt::Debug,
655+
{
656+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
657+
f.debug_struct("RouterAsService")
658+
.field("router", &self.router)
659+
.finish()
660+
}
661+
}
662+
563663
/// Wrapper around `matchit::Router` that supports merging two `Router`s.
564664
#[derive(Clone, Default)]
565665
struct Node {

examples/hyper-1-0/src/main.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@ use axum::{routing::get, Router};
88
use std::net::SocketAddr;
99
use tokio::net::TcpListener;
1010
use tower_http::trace::TraceLayer;
11-
use tower_hyper_http_body_compat::{
12-
HttpBody1ToHttpBody04, TowerService03HttpServiceAsHyper1HttpService,
13-
};
11+
use tower_hyper_http_body_compat::TowerService03HttpServiceAsHyper1HttpService;
1412
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
1513

1614
// this is hyper 1.0
17-
use hyper::{body::Incoming, server::conn::http1};
15+
use hyper::server::conn::http1;
1816

1917
#[tokio::main]
2018
async fn main() {
@@ -26,8 +24,7 @@ async fn main() {
2624
.with(tracing_subscriber::fmt::layer())
2725
.init();
2826

29-
// you have to use `HttpBody1ToHttpBody04<Incoming>` as the second type parameter to `Router`
30-
let app: Router<_, HttpBody1ToHttpBody04<Incoming>> = Router::new()
27+
let app = Router::new()
3128
.route("/", get(|| async { "Hello, World!" }))
3229
// we can still add regular tower middleware
3330
.layer(TraceLayer::new_for_http());

examples/static-file-server/src/main.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,7 @@ fn using_serve_dir_with_handler_as_service() -> Router {
7272
}
7373

7474
// you can convert handler function to service
75-
let service = tower::ServiceExt::<Request<Body>>::map_err(
76-
handle_404.into_service(),
77-
|err| -> std::io::Error { match err {} },
78-
);
75+
let service = handle_404.into_service();
7976

8077
let serve_dir = ServeDir::new("assets").not_found_service(service);
8178

examples/testing/src/main.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,14 @@ mod tests {
196196
.uri("/requires-connect-into")
197197
.body(Body::empty())
198198
.unwrap();
199-
let response = app.ready().await.unwrap().call(request).await.unwrap();
199+
let response = app
200+
.as_service()
201+
.ready()
202+
.await
203+
.unwrap()
204+
.call(request)
205+
.await
206+
.unwrap();
200207
assert_eq!(response.status(), StatusCode::OK);
201208
}
202209
}

0 commit comments

Comments
 (0)