diff --git a/bitreq/src/client.rs b/bitreq/src/client.rs index b5de6f2f..0c97ccbf 100644 --- a/bitreq/src/client.rs +++ b/bitreq/src/client.rs @@ -1,14 +1,17 @@ -//! Connection pooling client for HTTP requests. +//! Connection pooling [`Client`] for HTTP requests. //! -//! The `Client` caches connections to avoid repeated TCP handshakes and TLS negotiations. +//! The [`Client`] caches connections to avoid repeated TCP handshakes and TLS negotiations. //! -//! Due to std limitations, `Client` currently only supports async requests. - -#![cfg(feature = "async")] +//! When the `async` feature is enabled, the client uses async connections via `tokio`. +//! Otherwise a blocking client backed by `std::net::TcpStream` is provided. use std::collections::{hash_map, HashMap, VecDeque}; use std::sync::{Arc, Mutex}; +// --------------------------------------------------------------------------- +// Async Client (feature = "async") +// --------------------------------------------------------------------------- +#[cfg(feature = "async")] use crate::connection::AsyncConnection; use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest}; use crate::{Error, Request, Response}; @@ -30,17 +33,20 @@ use crate::{Error, Request, Response}; /// .await; /// # } /// ``` +#[cfg(feature = "async")] #[derive(Clone)] pub struct Client { - r#async: Arc>>, + r#async: Arc>, } -struct ClientImpl { - connections: HashMap>, +#[cfg(feature = "async")] +struct AsyncClientState { + connections: HashMap>, lru_order: VecDeque, capacity: usize, } +#[cfg(feature = "async")] impl Client { /// Creates a new `Client` with the specified connection cache capacity. /// @@ -50,7 +56,7 @@ impl Client { /// reached, the least recently used connection is evicted. pub fn new(capacity: usize) -> Self { Client { - r#async: Arc::new(Mutex::new(ClientImpl { + r#async: Arc::new(Mutex::new(AsyncClientState { connections: HashMap::new(), lru_order: VecDeque::new(), capacity, @@ -98,7 +104,8 @@ impl Client { } } -/// Extension trait for `Request` to use with `Client`. +/// Extension trait for [`Request`] to use with [`Client`]. +#[cfg(feature = "async")] pub trait RequestExt { /// Sends this request asynchronously using the provided client's connection pool. fn send_async_with_client( @@ -107,6 +114,7 @@ pub trait RequestExt { ) -> impl std::future::Future>; } +#[cfg(feature = "async")] impl RequestExt for Request { fn send_async_with_client( self, @@ -115,3 +123,177 @@ impl RequestExt for Request { client.send_async(self) } } + +// --------------------------------------------------------------------------- +// Blocking Client (no "async" feature) +// --------------------------------------------------------------------------- + +#[cfg(not(feature = "async"))] +use core::time::Duration; +#[cfg(not(feature = "async"))] +use std::time::Instant; + +#[cfg(not(feature = "async"))] +use crate::connection::{Connection, HttpStream}; +#[cfg(not(feature = "async"))] +use crate::Method; + +#[cfg(not(feature = "async"))] +struct PoolEntry { + stream: HttpStream, + expires_at: Instant, +} + +/// A client that caches connections for reuse. +/// +/// The client maintains a pool of up to `capacity` connections, evicting +/// the least recently used connection when the cache is full. A cached +/// connection is reused when the server indicated `Connection: keep-alive` +/// and the keep-alive timeout has not yet expired. +/// +/// # Example +/// +/// ```no_run +/// # fn main() -> Result<(), bitreq::Error> { +/// use bitreq::{Client, RequestExt}; +/// +/// let client = Client::new(10); // Cache up to 10 connections +/// let response = bitreq::get("http://example.com") +/// .send_with_client(&client)?; +/// # Ok(()) } +/// ``` +#[cfg(not(feature = "async"))] +#[derive(Clone)] +pub struct Client { + state: Arc>, +} + +#[cfg(not(feature = "async"))] +struct BlockingClientState { + connections: HashMap, + lru_order: VecDeque, + capacity: usize, +} + +#[cfg(not(feature = "async"))] +impl Client { + /// Creates a new `Client` with the specified connection cache capacity. + /// + /// # Arguments + /// + /// * `capacity` - Maximum number of cached connections. When this limit is + /// reached, the least recently used connection is evicted. + pub fn new(capacity: usize) -> Self { + Client { + state: Arc::new(Mutex::new(BlockingClientState { + connections: HashMap::new(), + lru_order: VecDeque::new(), + capacity, + })), + } + } + + /// Sends a request using a cached connection if available. + pub fn send(&self, request: Request) -> Result { + let parsed_request = ParsedRequest::new(request)?; + self.send_inner(parsed_request) + } + + fn send_inner(&self, mut request: ParsedRequest) -> Result { + loop { + let key: ConnectionKey = request.connection_params().into(); + + // Get cached stream or create new connection + let connection = match self.take_stream(&key) { + Some(stream) => Connection::from_stream(stream), + None => Connection::new(request.connection_params(), request.timeout_at)?, + }; + + let (response, stream, req) = connection.send_for_pool(request)?; + request = req; + + // Cache stream if keep-alive, with expiry + if let Some(stream) = stream { + let expires_at = Self::parse_keep_alive_timeout(response.headers.get("keep-alive")); + self.put_stream(key, stream, expires_at); + } + + // Handle redirects + match response.status_code { + 301 | 302 | 303 | 307 => { + let location = response + .headers + .get("location") + .ok_or(Error::RedirectLocationMissing)? + .clone(); + request.redirect_to(&location)?; + if response.status_code == 303 { + match request.config.method { + Method::Post | Method::Put | Method::Delete => { + request.config.method = Method::Get; + } + _ => {} + } + } + continue; + } + _ => return Ok(response), + } + } + } + + fn take_stream(&self, key: &ConnectionKey) -> Option { + let mut state = self.state.lock().unwrap(); + if let Some(entry) = state.connections.remove(key) { + // Remove from LRU order + if let Some(pos) = state.lru_order.iter().position(|k| k == key) { + state.lru_order.remove(pos); + } + if entry.expires_at > Instant::now() { + return Some(entry.stream); + } + } + None + } + + fn put_stream(&self, key: ConnectionKey, stream: HttpStream, expires_at: Instant) { + let mut state = self.state.lock().unwrap(); + if let hash_map::Entry::Vacant(entry) = state.connections.entry(key.clone()) { + entry.insert(PoolEntry { stream, expires_at }); + state.lru_order.push_back(key); + if state.connections.len() > state.capacity { + if let Some(oldest_key) = state.lru_order.pop_front() { + state.connections.remove(&oldest_key); + } + } + } + } + + fn parse_keep_alive_timeout(keep_alive_header: Option<&String>) -> Instant { + let default_timeout = Instant::now() + Duration::from_secs(60); + if let Some(header) = keep_alive_header { + for param in header.split(',') { + if let Some((k, v)) = param.trim().split_once('=') { + if k.trim() == "timeout" { + if let Ok(secs) = v.parse::() { + return Instant::now() + Duration::from_secs(secs.saturating_sub(1)); + } + } + } + } + } + default_timeout + } +} + +/// Extension trait for [`Request`] to use with [`Client`]. +#[cfg(not(feature = "async"))] +pub trait RequestExt { + /// Sends this request using the provided client's connection pool. + fn send_with_client(self, client: &Client) -> Result; +} + +#[cfg(not(feature = "async"))] +impl RequestExt for Request { + fn send_with_client(self, client: &Client) -> Result { client.send(self) } +} diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index f8b98c13..22eff0d0 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -23,9 +23,7 @@ use tokio::net::TcpStream as AsyncTcpStream; use tokio::sync::Mutex as AsyncMutex; use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest}; -#[cfg(feature = "async")] -use crate::Response; -use crate::{Error, Method, ResponseLazy}; +use crate::{Error, Method, Response, ResponseLazy}; type UnsecuredStream = TcpStream; @@ -51,6 +49,18 @@ impl HttpStream { pub(crate) fn create_buffer(buffer: Vec) -> HttpStream { HttpStream::Buffer(std::io::Cursor::new(buffer)) } + + /// Updates the timeout deadline used for read/write operations on this stream. + #[cfg(not(feature = "async"))] + pub(crate) fn set_timeout_at(&mut self, timeout_at: Option) { + match self { + HttpStream::Unsecured(_, t) => *t = timeout_at, + #[cfg(feature = "rustls")] + HttpStream::Secured(_, t) => *t = timeout_at, + #[cfg(feature = "async")] + HttpStream::Buffer(_) => {} + } + } } fn timeout_err() -> io::Error { @@ -769,6 +779,57 @@ impl Connection { handle_redirects(request, response) }) } + + /// Creates a `Connection` from an existing [`HttpStream`]. + /// + /// Used by [`Client`](crate::Client) to wrap a pooled stream for reuse. + #[cfg(not(feature = "async"))] + pub(crate) fn from_stream(stream: HttpStream) -> Connection { Connection { stream } } + + /// Sends the request and reads the full response, returning: + /// - The [`Response`]. + /// - An [`Option`] for connection reuse (if `Connection: keep-alive` was present). + /// - The [`ParsedRequest`] for potential redirect handling by the caller. + /// + /// Unlike [`send`](Self::send), this method does **not** follow redirects, leaving + /// that responsibility to the caller (the [`Client`](crate::Client)). + #[cfg(not(feature = "async"))] + pub(crate) fn send_for_pool( + mut self, + request: ParsedRequest, + ) -> Result<(Response, Option, ParsedRequest), Error> { + enforce_timeout(request.timeout_at, move || { + let is_head = request.config.method == Method::Head; + let max_body_size = request.config.max_body_size; + + // Update stream timeout for this request + self.stream.set_timeout_at(request.timeout_at); + + // Send request + #[cfg(feature = "log")] + log::trace!("Writing HTTP request (pooled)."); + self.stream.write_all(&request.as_bytes())?; + + // Receive response + #[cfg(feature = "log")] + log::trace!("Reading HTTP response (pooled)."); + let mut response_lazy = ResponseLazy::from_stream( + self.stream, + request.config.max_headers_size, + request.config.max_status_line_len, + request.config.max_body_size, + )?; + + // Set URL on the response from the request + request.url.write_base_url_to(&mut response_lazy.url).unwrap(); + request.url.write_resource_to(&mut response_lazy.url).unwrap(); + + // Read full body and recover stream if keep-alive + let (response, stream) = + Response::create_pooled(response_lazy, is_head, max_body_size)?; + Ok((response, stream, request)) + }) + } } fn handle_redirects( diff --git a/bitreq/src/lib.rs b/bitreq/src/lib.rs index 1f214dfa..dcc4b887 100644 --- a/bitreq/src/lib.rs +++ b/bitreq/src/lib.rs @@ -58,8 +58,9 @@ //! [`send_lazy_async()`](struct.Request.html#method.send_lazy_async) methods //! that return futures for non-blocking operation. //! -//! It also enables [`Client`](struct.Client.html) to reuse TCP connections -//! across requests. +//! When this feature is enabled the [`Client`](struct.Client.html) connection +//! pool uses async connections. Without it, a blocking [`Client`] is provided +//! instead. //! //! ## `async-https` or `async-https-rustls` //! @@ -263,7 +264,7 @@ mod request; mod response; mod url; -#[cfg(feature = "async")] +#[cfg(feature = "std")] pub use client::{Client, RequestExt}; pub use error::*; #[cfg(feature = "proxy")] diff --git a/bitreq/src/response.rs b/bitreq/src/response.rs index b234de10..020f790e 100644 --- a/bitreq/src/response.rs +++ b/bitreq/src/response.rs @@ -3,7 +3,7 @@ use core::str; #[cfg(feature = "async")] use std::future::Future; #[cfg(feature = "std")] -use std::io::{self, BufReader, Bytes, Read}; +use std::io::{self, BufReader, Read}; #[cfg(feature = "async")] use tokio::io::{AsyncRead, AsyncReadExt}; @@ -74,6 +74,49 @@ impl Response { Ok(Response { status_code, reason_phrase, headers, url, body }) } + /// Like [`create`](Self::create), but recovers the underlying [`HttpStream`] for + /// connection reuse when the server indicated `Connection: keep-alive`. + #[cfg(all(feature = "std", not(feature = "async")))] + pub(crate) fn create_pooled( + mut parent: ResponseLazy, + is_head: bool, + max_body_size: Option, + ) -> Result<(Response, Option), Error> { + let mut body = Vec::new(); + if !is_head && parent.status_code != 204 && parent.status_code != 304 { + for byte in &mut parent { + let (byte, length) = byte?; + if max_body_size.is_some_and(|max| body.len().saturating_add(length) > max) { + return Err(Error::BodyOverflow); + } + body.reserve(length); + body.push(byte); + } + } + + let keep_alive = parent + .headers + .get("connection") + .map_or(false, |h| h.eq_ignore_ascii_case("keep-alive")); + + let ResponseLazy { status_code, reason_phrase, headers, url, stream, .. } = parent; + + let recovered_stream = if keep_alive { + let buf_reader = stream.into_buf_reader(); + // Only recover the stream if the BufReader's buffer is empty, ensuring + // no unread bytes remain that would corrupt the next response. + if buf_reader.buffer().is_empty() { + Some(buf_reader.into_inner()) + } else { + None + } + } else { + None + }; + + Ok((Response { status_code, reason_phrase, headers, url, body }, recovered_stream)) + } + #[cfg(feature = "async")] /// Fully read a [`Response`] from an async stream. /// @@ -308,15 +351,51 @@ pub struct ResponseLazy { /// ). pub url: String, - stream: HttpStreamBytes, + stream: StreamBytes, state: HttpStreamState, max_trailing_headers_size: Option, max_body_size: Option, bytes_read: usize, } +/// A byte iterator over an [`HttpStream`] that allows recovering the inner stream. +/// +/// This is equivalent to [`std::io::Bytes`] but provides [`into_buf_reader`] to +/// extract the underlying [`BufReader`] (and ultimately the [`HttpStream`]) after +/// the response has been fully read, enabling connection reuse. +/// +/// [`into_buf_reader`]: StreamBytes::into_buf_reader #[cfg(feature = "std")] -type HttpStreamBytes = Bytes>; +struct StreamBytes { + inner: BufReader, +} + +#[cfg(feature = "std")] +impl StreamBytes { + fn new(stream: HttpStream, capacity: usize) -> Self { + StreamBytes { inner: BufReader::with_capacity(capacity, stream) } + } + + #[cfg(not(feature = "async"))] + fn into_buf_reader(self) -> BufReader { self.inner } +} + +#[cfg(feature = "std")] +impl Iterator for StreamBytes { + type Item = Result; + + fn next(&mut self) -> Option { + let mut byte = 0; + loop { + return match self.inner.read(core::slice::from_mut(&mut byte)) { + Ok(0) => None, + Ok(..) => Some(Ok(byte)), + Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue, + Err(e) => Some(Err(e)), + }; + } + } +} #[cfg(feature = "std")] impl ResponseLazy { @@ -326,7 +405,7 @@ impl ResponseLazy { max_status_line_len: Option, max_body_size: Option, ) -> Result { - let mut stream = BufReader::with_capacity(BACKING_READ_BUFFER_LENGTH, stream).bytes(); + let mut stream = StreamBytes::new(stream, BACKING_READ_BUFFER_LENGTH); let ResponseMetadata { status_code, reason_phrase, @@ -356,7 +435,7 @@ impl ResponseLazy { reason_phrase: response.reason_phrase, headers: response.headers, url: response.url, - stream: BufReader::with_capacity(1, http_stream).bytes(), + stream: StreamBytes::new(http_stream, 1), state: HttpStreamState::EndOnClose, max_trailing_headers_size: None, // Body was already fully loaded and size-checked by send_async @@ -700,7 +779,7 @@ macro_rules! define_read_methods { } #[cfg(feature = "std")] -define_read_methods!((read_until_closed, read_with_content_length, read_trailers, read_chunked, read_metadata, read_line)<>, HttpStreamBytes); +define_read_methods!((read_until_closed, read_with_content_length, read_trailers, read_chunked, read_metadata, read_line)<>, StreamBytes); #[cfg(feature = "async")] define_read_methods!((read_until_closed_async, read_with_content_length_async, read_trailers_async, read_chunked_async, read_metadata_async, read_line_async), R, async, await); diff --git a/bitreq/tests/setup.rs b/bitreq/tests/setup.rs index 234da30a..71762b70 100644 --- a/bitreq/tests/setup.rs +++ b/bitreq/tests/setup.rs @@ -201,7 +201,6 @@ pub fn setup() { pub fn url(req: &str) -> String { format!("http://localhost:35562{}", req) } -#[cfg(feature = "async")] static CLIENT: std::sync::OnceLock = std::sync::OnceLock::new(); #[cfg(feature = "async")] static RUNTIME: std::sync::OnceLock = std::sync::OnceLock::new(); @@ -224,6 +223,26 @@ pub async fn maybe_make_request( (res, lazy_res) => panic!("{res:?} != {}", lazy_res.is_err()), } + // Test blocking Client path + #[cfg(not(feature = "async"))] + { + let client = CLIENT.get_or_init(|| bitreq::Client::new(100)); + let client_response = client.send(request.clone()); + match (&response, client_response) { + (Ok(resp), Ok(client_resp)) => { + assert_eq!(client_resp.status_code, resp.status_code); + assert_eq!(client_resp.reason_phrase, resp.reason_phrase); + assert_eq!(client_resp.as_bytes(), resp.as_bytes()); + } + (Err(e), Err(client_e)) => { + assert_eq!(format!("{e:?}"), format!("{client_e:?}")); + } + (res, client_res) => { + panic!("{res:?} != {client_res:?}"); + } + } + } + #[cfg(feature = "async")] { if let Ok(resp) = &response {