diff --git a/postgres/src/lib.rs b/postgres/src/lib.rs index 56b271a4..a6c79cb8 100644 --- a/postgres/src/lib.rs +++ b/postgres/src/lib.rs @@ -25,7 +25,7 @@ mod config; use std::{ borrow::Cow, collections::HashMap, - fmt, + error, fmt, ops::{Deref, DerefMut}, sync::{ atomic::{AtomicUsize, Ordering}, @@ -63,12 +63,35 @@ pub type Client = Object; type RecycleResult = deadpool::managed::RecycleResult; type RecycleError = deadpool::managed::RecycleError; +/// Allows dynamic configuration for new database connections. +#[async_trait] +pub trait ConfigSource +where + Self: Sync + Send + fmt::Debug, +{ + /// Returns the current [`PgConfig`]. + /// Called to get the configuration for each new connection. + async fn get_config(&self) -> Result, Box>; +} + +#[derive(Debug)] +struct StaticConfigSource { + pg_config: Arc, +} + +#[async_trait] +impl ConfigSource for StaticConfigSource { + async fn get_config(&self) -> Result, Box> { + Ok(self.pg_config.clone()) + } +} + /// [`Manager`] for creating and recycling PostgreSQL connections. /// /// [`Manager`]: managed::Manager pub struct Manager { config: ManagerConfig, - pg_config: PgConfig, + pg_config: Box, connect: Box, /// [`StatementCaches`] of [`Client`]s handed out by the [`Pool`]. pub statement_caches: StatementCaches, @@ -98,7 +121,27 @@ impl Manager { { Self { config, - pg_config, + pg_config: Box::new(StaticConfigSource { + pg_config: Arc::new(pg_config), + }), + connect: Box::new(ConnectImpl { tls }), + statement_caches: StatementCaches::default(), + } + } + + /// Create a new [`Manager`] that allows dynamic configuration using + /// the given [`ConfigSource`]. + pub fn from_config_source(pg_config: C, tls: T, config: ManagerConfig) -> Self + where + T: MakeTlsConnect + Clone + Sync + Send + 'static, + T::Stream: Sync + Send, + T::TlsConnect: Sync + Send, + >::Future: Send, + C: ConfigSource + 'static, + { + Self { + config, + pg_config: Box::new(pg_config), connect: Box::new(ConnectImpl { tls }), statement_caches: StatementCaches::default(), } @@ -122,7 +165,15 @@ impl managed::Manager for Manager { type Error = Error; async fn create(&self) -> Result { - let client = self.connect.connect(&self.pg_config).await?; + let client = self + .connect + .connect( + self.pg_config + .get_config() + .await + .map_err(|e| Error::new_external(e))?, + ) + .await?; let client_wrapper = ClientWrapper::new(client); self.statement_caches .attach(&client_wrapper.statement_cache); @@ -153,7 +204,7 @@ impl managed::Manager for Manager { #[async_trait] trait Connect: Sync + Send { - async fn connect(&self, pg_config: &PgConfig) -> Result; + async fn connect(&self, pg_config: Arc) -> Result; } struct ConnectImpl @@ -174,7 +225,7 @@ where T::TlsConnect: Sync + Send, >::Future: Send, { - async fn connect(&self, pg_config: &PgConfig) -> Result { + async fn connect(&self, pg_config: Arc) -> Result { let (client, connection) = pg_config.connect(self.tls.clone()).await?; drop(spawn(async move { if let Err(e) = connection.await {