Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 21 additions & 10 deletions postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use std::{
};

use deadpool::{async_trait, managed};
use tokio::spawn;
use tokio::{spawn, task::JoinHandle};
use tokio_postgres::{
tls::MakeTlsConnect, tls::TlsConnect, types::Type, Client as PgClient, Config as PgConfig,
Error, IsolationLevel, Socket, Statement, Transaction as PgTransaction,
Expand Down Expand Up @@ -125,8 +125,8 @@ impl managed::Manager for Manager {
type Error = Error;

async fn create(&self) -> Result<ClientWrapper, Error> {
let client = self.connect.connect(&self.pg_config).await?;
let client_wrapper = ClientWrapper::new(client);
let (client, conn_task) = self.connect.connect(&self.pg_config).await?;
let client_wrapper = ClientWrapper::new(client, conn_task);
self.statement_caches
.attach(&client_wrapper.statement_cache);
Ok(client_wrapper)
Expand Down Expand Up @@ -156,7 +156,7 @@ impl managed::Manager for Manager {

#[async_trait]
trait Connect: Sync + Send {
async fn connect(&self, pg_config: &PgConfig) -> Result<PgClient, Error>;
async fn connect(&self, pg_config: &PgConfig) -> Result<(PgClient, JoinHandle<()>), Error>;
}

struct ConnectImpl<T>
Expand All @@ -177,14 +177,14 @@ where
T::TlsConnect: Sync + Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
async fn connect(&self, pg_config: &PgConfig) -> Result<PgClient, Error> {
async fn connect(&self, pg_config: &PgConfig) -> Result<(PgClient, JoinHandle<()>), Error> {
let (client, connection) = pg_config.connect(self.tls.clone()).await?;
drop(spawn(async move {
let conn_task = spawn(async move {
if let Err(e) = connection.await {
log::warn!(target: "deadpool.postgres", "Connection error: {}", e);
}
}));
Ok(client)
});
Ok((client, conn_task))
}
}

Expand Down Expand Up @@ -369,17 +369,22 @@ pub struct ClientWrapper {
/// Original [`PgClient`].
client: PgClient,

/// A handle to the connection task that should be aborted when the client
/// wrapper is dropped.
conn_task: JoinHandle<()>,

/// [`StatementCache`] of this client.
pub statement_cache: Arc<StatementCache>,
}

impl ClientWrapper {
/// Create a new [`ClientWrapper`] instance using the given
/// [`tokio_postgres::Client`].
/// [`tokio_postgres::Client`] and handle to the connection task.
#[must_use]
pub fn new(client: PgClient) -> Self {
pub fn new(client: PgClient, conn_task: JoinHandle<()>) -> Self {
Self {
client,
conn_task,
statement_cache: Arc::new(StatementCache::new()),
}
}
Expand Down Expand Up @@ -436,6 +441,12 @@ impl DerefMut for ClientWrapper {
}
}

impl Drop for ClientWrapper {
fn drop(&mut self) {
self.conn_task.abort()
}
}

/// Wrapper around [`tokio_postgres::Transaction`] with a [`StatementCache`]
/// from the [`Client`] object it was created by.
pub struct Transaction<'a> {
Expand Down