diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index d2500359..bb716933 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -35,7 +35,7 @@ use std::{ }; use deadpool::{async_trait, managed}; -use tokio::spawn; +use tokio::{spawn, task::JoinHandle}; use tokio_postgres::{ tls::MakeTlsConnect, tls::TlsConnect, types::Type, Client as PgClient, Config as PgConfig, Error, IsolationLevel, Socket, Statement, Transaction as PgTransaction, @@ -125,8 +125,8 @@ impl managed::Manager for Manager { type Error = Error; async fn create(&self) -> Result { - let client = self.connect.connect(&self.pg_config).await?; - let client_wrapper = ClientWrapper::new(client); + let (client, conn_task) = self.connect.connect(&self.pg_config).await?; + let client_wrapper = ClientWrapper::new(client, conn_task); self.statement_caches .attach(&client_wrapper.statement_cache); Ok(client_wrapper) @@ -156,7 +156,7 @@ impl managed::Manager for Manager { #[async_trait] trait Connect: Sync + Send { - async fn connect(&self, pg_config: &PgConfig) -> Result; + async fn connect(&self, pg_config: &PgConfig) -> Result<(PgClient, JoinHandle<()>), Error>; } struct ConnectImpl @@ -177,14 +177,14 @@ where T::TlsConnect: Sync + Send, >::Future: Send, { - async fn connect(&self, pg_config: &PgConfig) -> Result { + async fn connect(&self, pg_config: &PgConfig) -> Result<(PgClient, JoinHandle<()>), Error> { let (client, connection) = pg_config.connect(self.tls.clone()).await?; - drop(spawn(async move { + let conn_task = spawn(async move { if let Err(e) = connection.await { log::warn!(target: "deadpool.postgres", "Connection error: {}", e); } - })); - Ok(client) + }); + Ok((client, conn_task)) } } @@ -369,17 +369,22 @@ pub struct ClientWrapper { /// Original [`PgClient`]. client: PgClient, + /// A handle to the connection task that should be aborted when the client + /// wrapper is dropped. + conn_task: JoinHandle<()>, + /// [`StatementCache`] of this client. pub statement_cache: Arc, } impl ClientWrapper { /// Create a new [`ClientWrapper`] instance using the given - /// [`tokio_postgres::Client`]. + /// [`tokio_postgres::Client`] and handle to the connection task. #[must_use] - pub fn new(client: PgClient) -> Self { + pub fn new(client: PgClient, conn_task: JoinHandle<()>) -> Self { Self { client, + conn_task, statement_cache: Arc::new(StatementCache::new()), } } @@ -436,6 +441,12 @@ impl DerefMut for ClientWrapper { } } +impl Drop for ClientWrapper { + fn drop(&mut self) { + self.conn_task.abort() + } +} + /// Wrapper around [`tokio_postgres::Transaction`] with a [`StatementCache`] /// from the [`Client`] object it was created by. pub struct Transaction<'a> {