Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 192 additions & 10 deletions bitreq/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
//! Connection pooling client for HTTP requests.
//! Connection pooling [`Client`] for HTTP requests.
//!
//! The `Client` caches connections to avoid repeated TCP handshakes and TLS negotiations.
//! The [`Client`] caches connections to avoid repeated TCP handshakes and TLS negotiations.
//!
//! Due to std limitations, `Client` currently only supports async requests.

#![cfg(feature = "async")]
//! When the `async` feature is enabled, the client uses async connections via `tokio`.
//! Otherwise a blocking client backed by `std::net::TcpStream` is provided.

use std::collections::{hash_map, HashMap, VecDeque};
use std::sync::{Arc, Mutex};

// ---------------------------------------------------------------------------
// Async Client (feature = "async")
// ---------------------------------------------------------------------------
#[cfg(feature = "async")]
use crate::connection::AsyncConnection;
use crate::request::{OwnedConnectionParams as ConnectionKey, ParsedRequest};
use crate::{Error, Request, Response};
Expand All @@ -30,17 +33,20 @@ use crate::{Error, Request, Response};
/// .await;
/// # }
/// ```
#[cfg(feature = "async")]
#[derive(Clone)]
pub struct Client {
r#async: Arc<Mutex<ClientImpl<AsyncConnection>>>,
r#async: Arc<Mutex<AsyncClientState>>,
}

struct ClientImpl<T> {
connections: HashMap<ConnectionKey, Arc<T>>,
/// Shared mutable state behind the async [`Client`]: the connection pool
/// plus the bookkeeping needed for LRU eviction.
#[cfg(feature = "async")]
struct AsyncClientState {
    /// Pooled async connections, keyed by their connection parameters.
    connections: HashMap<ConnectionKey, Arc<AsyncConnection>>,
    /// Keys in least-recently-used order; the front is the next eviction candidate.
    lru_order: VecDeque<ConnectionKey>,
    /// Maximum number of cached connections before LRU eviction kicks in.
    capacity: usize,
}

#[cfg(feature = "async")]
impl Client {
/// Creates a new `Client` with the specified connection cache capacity.
///
Expand All @@ -50,7 +56,7 @@ impl Client {
/// reached, the least recently used connection is evicted.
pub fn new(capacity: usize) -> Self {
Client {
r#async: Arc::new(Mutex::new(ClientImpl {
r#async: Arc::new(Mutex::new(AsyncClientState {
connections: HashMap::new(),
lru_order: VecDeque::new(),
capacity,
Expand Down Expand Up @@ -98,7 +104,8 @@ impl Client {
}
}

/// Extension trait for `Request` to use with `Client`.
/// Extension trait for [`Request`] to use with [`Client`].
#[cfg(feature = "async")]
pub trait RequestExt {
/// Sends this request asynchronously using the provided client's connection pool.
fn send_async_with_client(
Expand All @@ -107,6 +114,7 @@ pub trait RequestExt {
) -> impl std::future::Future<Output = Result<Response, Error>>;
}

#[cfg(feature = "async")]
impl RequestExt for Request {
fn send_async_with_client(
self,
Expand All @@ -115,3 +123,177 @@ impl RequestExt for Request {
client.send_async(self)
}
}

// ---------------------------------------------------------------------------
// Blocking Client (no "async" feature)
// ---------------------------------------------------------------------------

#[cfg(not(feature = "async"))]
use core::time::Duration;
#[cfg(not(feature = "async"))]
use std::time::Instant;

#[cfg(not(feature = "async"))]
use crate::connection::{Connection, HttpStream};
#[cfg(not(feature = "async"))]
use crate::Method;

/// A pooled, idle connection waiting to be reused.
#[cfg(not(feature = "async"))]
struct PoolEntry {
    /// The underlying stream, recovered after a keep-alive response was fully read.
    stream: HttpStream,
    /// Deadline after which this connection is considered stale and is discarded
    /// instead of being reused (derived from the server's `Keep-Alive` header).
    expires_at: Instant,
}

/// A client that caches connections for reuse.
///
/// The client maintains a pool of up to `capacity` connections, evicting
/// the least recently used connection when the cache is full. A cached
/// connection is reused when the server indicated `Connection: keep-alive`
/// and the keep-alive timeout has not yet expired.
///
/// # Example
///
/// ```no_run
/// # fn main() -> Result<(), bitreq::Error> {
/// use bitreq::{Client, RequestExt};
///
/// let client = Client::new(10); // Cache up to 10 connections
/// let response = bitreq::get("http://example.com")
/// .send_with_client(&client)?;
/// # Ok(()) }
/// ```
#[cfg(not(feature = "async"))]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should never disable code based on a feature being added.

#[derive(Clone)]
pub struct Client {
state: Arc<Mutex<BlockingClientState>>,
}

/// Shared mutable state behind the blocking [`Client`]: the connection pool
/// plus the bookkeeping needed for LRU eviction.
#[cfg(not(feature = "async"))]
struct BlockingClientState {
    /// Pooled idle connections, keyed by their connection parameters.
    connections: HashMap<ConnectionKey, PoolEntry>,
    /// Keys in least-recently-used order; the front is the next eviction candidate.
    lru_order: VecDeque<ConnectionKey>,
    /// Maximum number of pooled connections before LRU eviction kicks in.
    capacity: usize,
}

#[cfg(not(feature = "async"))]
impl Client {
    /// Creates a new `Client` with the specified connection cache capacity.
    ///
    /// # Arguments
    ///
    /// * `capacity` - Maximum number of cached connections. When this limit is
    ///   reached, the least recently used connection is evicted.
    pub fn new(capacity: usize) -> Self {
        Client {
            state: Arc::new(Mutex::new(BlockingClientState {
                connections: HashMap::new(),
                lru_order: VecDeque::new(),
                capacity,
            })),
        }
    }

    /// Sends a request using a cached connection if available.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the request cannot be parsed, the connection
    /// fails, or a redirect response is missing a `Location` header.
    pub fn send(&self, request: Request) -> Result<Response, Error> {
        let parsed_request = ParsedRequest::new(request)?;
        self.send_inner(parsed_request)
    }

    /// Drives the request/response cycle: reuses a pooled stream when one is
    /// available, returns keep-alive streams to the pool, and follows
    /// redirects until a non-redirect response is received.
    fn send_inner(&self, mut request: ParsedRequest) -> Result<Response, Error> {
        loop {
            let key: ConnectionKey = request.connection_params().into();

            // Get cached stream or create new connection.
            let connection = match self.take_stream(&key) {
                Some(stream) => Connection::from_stream(stream),
                None => Connection::new(request.connection_params(), request.timeout_at)?,
            };

            let (response, stream, req) = connection.send_for_pool(request)?;
            request = req;

            // Cache the stream if the server kept the connection alive,
            // recording when it should no longer be trusted for reuse.
            if let Some(stream) = stream {
                let expires_at = Self::parse_keep_alive_timeout(response.headers.get("keep-alive"));
                self.put_stream(key, stream, expires_at);
            }

            // Handle redirects; any other status is returned to the caller.
            match response.status_code {
                301 | 302 | 303 | 307 => {
                    let location = response
                        .headers
                        .get("location")
                        .ok_or(Error::RedirectLocationMissing)?
                        .clone();
                    request.redirect_to(&location)?;
                    // A 303 (See Other) response rewrites body-carrying
                    // methods to GET; other redirect codes keep the method.
                    if response.status_code == 303 {
                        match request.config.method {
                            Method::Post | Method::Put | Method::Delete => {
                                request.config.method = Method::Get;
                            }
                            _ => {}
                        }
                    }
                    continue;
                }
                _ => return Ok(response),
            }
        }
    }

    /// Removes and returns the pooled stream for `key`, if one exists and its
    /// keep-alive deadline has not passed. Expired entries are dropped (the
    /// removal above already took them out of the map and the LRU order).
    fn take_stream(&self, key: &ConnectionKey) -> Option<HttpStream> {
        let mut state = self.state.lock().unwrap();
        if let Some(entry) = state.connections.remove(key) {
            // Keep the LRU order in sync with the map.
            if let Some(pos) = state.lru_order.iter().position(|k| k == key) {
                state.lru_order.remove(pos);
            }
            if entry.expires_at > Instant::now() {
                return Some(entry.stream);
            }
        }
        None
    }

    /// Inserts `stream` into the pool under `key`, evicting the least
    /// recently used entry when the capacity is exceeded. If an entry for
    /// `key` already exists, the new stream is simply dropped (closed).
    fn put_stream(&self, key: ConnectionKey, stream: HttpStream, expires_at: Instant) {
        let mut state = self.state.lock().unwrap();
        if let hash_map::Entry::Vacant(entry) = state.connections.entry(key.clone()) {
            entry.insert(PoolEntry { stream, expires_at });
            state.lru_order.push_back(key);
            if state.connections.len() > state.capacity {
                if let Some(oldest_key) = state.lru_order.pop_front() {
                    state.connections.remove(&oldest_key);
                }
            }
        }
    }

    /// Computes the reuse deadline from a `Keep-Alive` response header value.
    ///
    /// Falls back to 60 seconds when the header is absent or unparseable.
    /// One second is subtracted from an advertised timeout as a safety margin
    /// so a connection is never reused just as the server closes it.
    fn parse_keep_alive_timeout(keep_alive_header: Option<&String>) -> Instant {
        let default_timeout = Instant::now() + Duration::from_secs(60);
        if let Some(header) = keep_alive_header {
            for param in header.split(',') {
                if let Some((k, v)) = param.trim().split_once('=') {
                    if k.trim() == "timeout" {
                        // Trim the value too: `timeout = 5` (whitespace around
                        // `=`) would otherwise fail to parse and silently fall
                        // back to the default.
                        if let Ok(secs) = v.trim().parse::<u64>() {
                            return Instant::now() + Duration::from_secs(secs.saturating_sub(1));
                        }
                    }
                }
            }
        }
        default_timeout
    }
}

/// Extension trait for [`Request`] to use with [`Client`].
///
/// Lets a request be sent through a client's connection pool with
/// method-call syntax instead of `client.send(request)`.
#[cfg(not(feature = "async"))]
pub trait RequestExt {
    /// Sends this request using the provided client's connection pool.
    ///
    /// Equivalent to calling [`Client::send`] with this request.
    fn send_with_client(self, client: &Client) -> Result<Response, Error>;
}

#[cfg(not(feature = "async"))]
impl RequestExt for Request {
    // Delegates to the client's pooled `send`, consuming the request.
    fn send_with_client(self, client: &Client) -> Result<Response, Error> { client.send(self) }
}
67 changes: 64 additions & 3 deletions bitreq/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ use tokio::net::TcpStream as AsyncTcpStream;
use tokio::sync::Mutex as AsyncMutex;

use crate::request::{ConnectionParams, OwnedConnectionParams, ParsedRequest};
#[cfg(feature = "async")]
use crate::Response;
use crate::{Error, Method, ResponseLazy};
use crate::{Error, Method, Response, ResponseLazy};

type UnsecuredStream = TcpStream;

Expand All @@ -51,6 +49,18 @@ impl HttpStream {
pub(crate) fn create_buffer(buffer: Vec<u8>) -> HttpStream {
HttpStream::Buffer(std::io::Cursor::new(buffer))
}

/// Updates the timeout deadline used for read/write operations on this stream.
///
/// Buffer-backed streams are purely in-memory, so no deadline is stored for them.
#[cfg(not(feature = "async"))]
pub(crate) fn set_timeout_at(&mut self, timeout_at: Option<Instant>) {
    match self {
        HttpStream::Unsecured(_, t) => *t = timeout_at,
        #[cfg(feature = "rustls")]
        HttpStream::Secured(_, t) => *t = timeout_at,
        // NOTE(review): this arm is gated on the `async` feature while the
        // function is gated on `not(async)`, so the arm can never be compiled
        // in. If the `Buffer` variant exists without the `async` feature this
        // match would be non-exhaustive — confirm the enum's cfg gating.
        #[cfg(feature = "async")]
        HttpStream::Buffer(_) => {}
    }
}
}

fn timeout_err() -> io::Error {
Expand Down Expand Up @@ -769,6 +779,57 @@ impl Connection {
handle_redirects(request, response)
})
}

/// Creates a `Connection` from an existing [`HttpStream`].
///
/// Used by [`Client`](crate::Client) to wrap a pooled stream for reuse,
/// skipping the TCP/TLS setup that [`Connection::new`] performs.
#[cfg(not(feature = "async"))]
pub(crate) fn from_stream(stream: HttpStream) -> Connection { Connection { stream } }

/// Sends the request and reads the full response, returning:
/// - The [`Response`].
/// - An [`Option<HttpStream>`] for connection reuse (if `Connection: keep-alive` was present).
/// - The [`ParsedRequest`] for potential redirect handling by the caller.
///
/// Unlike [`send`](Self::send), this method does **not** follow redirects, leaving
/// that responsibility to the caller (the [`Client`](crate::Client)).
#[cfg(not(feature = "async"))]
pub(crate) fn send_for_pool(
    mut self,
    request: ParsedRequest,
) -> Result<(Response, Option<HttpStream>, ParsedRequest), Error> {
    // The whole exchange runs under the request's deadline.
    enforce_timeout(request.timeout_at, move || {
        // Captured up front because `request` is moved back out at the end.
        let is_head = request.config.method == Method::Head;
        let max_body_size = request.config.max_body_size;

        // Update stream timeout for this request: a pooled stream still
        // carries the deadline of the request that created it.
        self.stream.set_timeout_at(request.timeout_at);

        // Send request
        #[cfg(feature = "log")]
        log::trace!("Writing HTTP request (pooled).");
        self.stream.write_all(&request.as_bytes())?;

        // Receive response
        #[cfg(feature = "log")]
        log::trace!("Reading HTTP response (pooled).");
        let mut response_lazy = ResponseLazy::from_stream(
            self.stream,
            request.config.max_headers_size,
            request.config.max_status_line_len,
            request.config.max_body_size,
        )?;

        // Set URL on the response from the request.
        // NOTE(review): the unwraps assume writing into the response's URL
        // buffer is infallible (in-memory formatting) — confirm that
        // `write_base_url_to`/`write_resource_to` cannot fail otherwise.
        request.url.write_base_url_to(&mut response_lazy.url).unwrap();
        request.url.write_resource_to(&mut response_lazy.url).unwrap();

        // Read full body and recover stream if keep-alive
        let (response, stream) =
            Response::create_pooled(response_lazy, is_head, max_body_size)?;
        Ok((response, stream, request))
    })
}
}

fn handle_redirects(
Expand Down
7 changes: 4 additions & 3 deletions bitreq/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@
//! [`send_lazy_async()`](struct.Request.html#method.send_lazy_async) methods
//! that return futures for non-blocking operation.
//!
//! It also enables [`Client`](struct.Client.html) to reuse TCP connections
//! across requests.
//! When this feature is enabled the [`Client`](struct.Client.html) connection
//! pool uses async connections. Without it, a blocking [`Client`] is provided
//! instead.
//!
//! ## `async-https` or `async-https-rustls`
//!
Expand Down Expand Up @@ -263,7 +264,7 @@ mod request;
mod response;
mod url;

#[cfg(feature = "async")]
#[cfg(feature = "std")]
pub use client::{Client, RequestExt};
pub use error::*;
#[cfg(feature = "proxy")]
Expand Down
Loading
Loading