diff --git a/src/database/mod.rs b/src/database/mod.rs index 16d6d524c..6092362d8 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,5 +1,6 @@ use std::{sync::Arc, time::Duration}; +#[cfg(not(feature = "sync"))] use futures_util::future::BoxFuture; #[cfg(feature = "sqlx-mysql")] use sqlx::mysql::MySqlConnectOptions; @@ -53,6 +54,29 @@ pub struct Database; #[cfg(feature = "sync")] type BoxFuture<'a, T> = T; +#[cfg(feature = "sqlx-mysql")] +type MapMySqlPoolOptsFn = Arc< + dyn Fn(sqlx::pool::PoolOptions) -> sqlx::pool::PoolOptions + + Send + + Sync, +>; + +#[cfg(feature = "sqlx-postgres")] +type MapPgPoolOptsFn = Arc< + dyn Fn(sqlx::pool::PoolOptions) -> sqlx::pool::PoolOptions + + Send + + Sync, +>; + +#[cfg(feature = "sqlx-sqlite")] +type MapSqlitePoolOptsFn = Option< + Arc< + dyn Fn(sqlx::pool::PoolOptions) -> sqlx::pool::PoolOptions + + Send + + Sync, + >, +>; + type AfterConnectCallback = Option< Arc< dyn Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + Send + Sync + 'static, @@ -98,6 +122,15 @@ pub struct ConnectOptions { #[debug(skip)] pub(crate) after_connect: AfterConnectCallback, + #[cfg(feature = "sqlx-mysql")] + #[debug(skip)] + pub(crate) mysql_pool_opts_fn: Option, + #[cfg(feature = "sqlx-postgres")] + #[debug(skip)] + pub(crate) pg_pool_opts_fn: Option, + #[cfg(feature = "sqlx-sqlite")] + #[debug(skip)] + pub(crate) sqlite_pool_opts_fn: MapSqlitePoolOptsFn, #[cfg(feature = "sqlx-mysql")] #[debug(skip)] pub(crate) mysql_opts_fn: @@ -218,6 +251,12 @@ impl ConnectOptions { connect_lazy: false, after_connect: None, #[cfg(feature = "sqlx-mysql")] + mysql_pool_opts_fn: None, + #[cfg(feature = "sqlx-postgres")] + pg_pool_opts_fn: None, + #[cfg(feature = "sqlx-sqlite")] + sqlite_pool_opts_fn: None, + #[cfg(feature = "sqlx-mysql")] mysql_opts_fn: None, #[cfg(feature = "sqlx-postgres")] pg_opts_fn: None, @@ -404,6 +443,21 @@ impl ConnectOptions { self } + #[cfg(feature = "sqlx-mysql")] + #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-mysql")))] + /// Apply a function to modify the underlying [`sqlx::pool::PoolOptions`] + /// before creating the connection pool. + pub fn map_sqlx_mysql_pool_opts(&mut self, f: F) -> &mut Self + where + F: Fn(sqlx::pool::PoolOptions) -> sqlx::pool::PoolOptions + + Send + + Sync + + 'static, + { + self.mysql_pool_opts_fn = Some(Arc::new(f)); + self + } + #[cfg(feature = "sqlx-postgres")] #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))] /// Apply a function to modify the underlying [`PgConnectOptions`] before @@ -416,6 +470,21 @@ impl ConnectOptions { self } + #[cfg(feature = "sqlx-postgres")] + #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))] + /// Apply a function to modify the underlying [`sqlx::pool::PoolOptions`] + /// before creating the connection pool. + pub fn map_sqlx_postgres_pool_opts(&mut self, f: F) -> &mut Self + where + F: Fn(sqlx::pool::PoolOptions) -> sqlx::pool::PoolOptions + + Send + + Sync + + 'static, + { + self.pg_pool_opts_fn = Some(Arc::new(f)); + self + } + #[cfg(feature = "sqlx-sqlite")] #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))] /// Apply a function to modify the underlying [`SqliteConnectOptions`] before @@ -427,4 +496,19 @@ impl ConnectOptions { self.sqlite_opts_fn = Some(Arc::new(f)); self } + + #[cfg(feature = "sqlx-sqlite")] + #[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))] + /// Apply a function to modify the underlying [`sqlx::pool::PoolOptions`] + /// before creating the connection pool. + pub fn map_sqlx_sqlite_pool_opts(&mut self, f: F) -> &mut Self + where + F: Fn(sqlx::pool::PoolOptions) -> sqlx::pool::PoolOptions + + Send + + Sync + + 'static, + { + self.sqlite_pool_opts_fn = Some(Arc::new(f)); + self + } } diff --git a/src/driver/sqlx_mysql.rs b/src/driver/sqlx_mysql.rs index 0649ac0e0..39583a6ce 100644 --- a/src/driver/sqlx_mysql.rs +++ b/src/driver/sqlx_mysql.rs @@ -81,14 +81,17 @@ impl SqlxMySqlConnector { if let Some(f) = &options.mysql_opts_fn { sqlx_opts = f(sqlx_opts); } - let after_connect = options.after_connect.clone(); - - let pool = if options.connect_lazy { - options.sqlx_pool_options().connect_lazy_with(sqlx_opts) + let connect_lazy = options.connect_lazy; + let mysql_pool_opts_fn = options.mysql_pool_opts_fn.clone(); + let mut pool_options = options.sqlx_pool_options(); + if let Some(f) = &mysql_pool_opts_fn { + pool_options = f(pool_options); + } + let pool = if connect_lazy { + pool_options.connect_lazy_with(sqlx_opts) } else { - options - .sqlx_pool_options() + pool_options .connect_with(sqlx_opts) .await .map_err(sqlx_error_to_conn_err)? diff --git a/src/driver/sqlx_postgres.rs b/src/driver/sqlx_postgres.rs index e6b86c50c..324b5e665 100644 --- a/src/driver/sqlx_postgres.rs +++ b/src/driver/sqlx_postgres.rs @@ -103,6 +103,7 @@ impl SqlxPostgresConnector { let lazy = options.connect_lazy; let after_connect = options.after_connect.clone(); + let pg_pool_opts_fn = options.pg_pool_opts_fn.clone(); let mut pool_options = options.sqlx_pool_options(); if let Some(sql) = set_search_path_sql { @@ -115,7 +116,9 @@ impl SqlxPostgresConnector { }) }); } - + if let Some(f) = &pg_pool_opts_fn { + pool_options = f(pool_options); + } let pool = if lazy { pool_options.connect_lazy_with(sqlx_opts) } else { diff --git a/src/driver/sqlx_sqlite.rs b/src/driver/sqlx_sqlite.rs index cb17bbc56..6e19c2d74 100644 --- a/src/driver/sqlx_sqlite.rs +++ b/src/driver/sqlx_sqlite.rs @@ -91,12 +91,18 @@ impl SqlxSqliteConnector { } let after_conn = options.after_connect.clone(); + let connect_lazy = options.connect_lazy; + let sqlite_pool_opts_fn = options.sqlite_pool_opts_fn.clone(); + let mut pool_options = options.sqlx_pool_options(); - let pool = if options.connect_lazy { - options.sqlx_pool_options().connect_lazy_with(sqlx_opts) + if let Some(f) = &sqlite_pool_opts_fn { + pool_options = f(pool_options); + } + + let pool = if connect_lazy { + pool_options.connect_lazy_with(sqlx_opts) } else { - options - .sqlx_pool_options() + pool_options .connect_with(sqlx_opts) .await .map_err(sqlx_error_to_conn_err)?