Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions src/database/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{sync::Arc, time::Duration};

#[cfg(not(feature = "sync"))]
use futures_util::future::BoxFuture;
#[cfg(feature = "sqlx-mysql")]
use sqlx::mysql::MySqlConnectOptions;
Expand Down Expand Up @@ -53,6 +54,29 @@ pub struct Database;
#[cfg(feature = "sync")]
type BoxFuture<'a, T> = T;

#[cfg(feature = "sqlx-mysql")]
type MapMySqlPoolOptsFn = Arc<
dyn Fn(sqlx::pool::PoolOptions<sqlx::MySql>) -> sqlx::pool::PoolOptions<sqlx::MySql>
+ Send
+ Sync,
>;

#[cfg(feature = "sqlx-postgres")]
type MapPgPoolOptsFn = Arc<
dyn Fn(sqlx::pool::PoolOptions<sqlx::Postgres>) -> sqlx::pool::PoolOptions<sqlx::Postgres>
+ Send
+ Sync,
>;

#[cfg(feature = "sqlx-sqlite")]
type MapSqlitePoolOptsFn = Option<
Arc<
dyn Fn(sqlx::pool::PoolOptions<sqlx::Sqlite>) -> sqlx::pool::PoolOptions<sqlx::Sqlite>
+ Send
+ Sync,
>,
>;

type AfterConnectCallback = Option<
Arc<
dyn Fn(DatabaseConnection) -> BoxFuture<'static, Result<(), DbErr>> + Send + Sync + 'static,
Expand Down Expand Up @@ -98,6 +122,15 @@ pub struct ConnectOptions {
#[debug(skip)]
pub(crate) after_connect: AfterConnectCallback,

#[cfg(feature = "sqlx-mysql")]
#[debug(skip)]
pub(crate) mysql_pool_opts_fn: Option<MapMySqlPoolOptsFn>,
#[cfg(feature = "sqlx-postgres")]
#[debug(skip)]
pub(crate) pg_pool_opts_fn: Option<MapPgPoolOptsFn>,
#[cfg(feature = "sqlx-sqlite")]
#[debug(skip)]
pub(crate) sqlite_pool_opts_fn: MapSqlitePoolOptsFn,
#[cfg(feature = "sqlx-mysql")]
#[debug(skip)]
pub(crate) mysql_opts_fn:
Expand Down Expand Up @@ -218,6 +251,12 @@ impl ConnectOptions {
connect_lazy: false,
after_connect: None,
#[cfg(feature = "sqlx-mysql")]
mysql_pool_opts_fn: None,
#[cfg(feature = "sqlx-postgres")]
pg_pool_opts_fn: None,
#[cfg(feature = "sqlx-sqlite")]
sqlite_pool_opts_fn: None,
#[cfg(feature = "sqlx-mysql")]
mysql_opts_fn: None,
#[cfg(feature = "sqlx-postgres")]
pg_opts_fn: None,
Expand Down Expand Up @@ -404,6 +443,21 @@ impl ConnectOptions {
self
}

#[cfg(feature = "sqlx-mysql")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqlx-mysql")))]
/// Apply a function to modify the underlying [`sqlx::pool::PoolOptions<sqlx::MySql>`]
/// before creating the connection pool.
pub fn map_sqlx_mysql_pool_opts<F>(&mut self, f: F) -> &mut Self
where
F: Fn(sqlx::pool::PoolOptions<sqlx::MySql>) -> sqlx::pool::PoolOptions<sqlx::MySql>
+ Send
+ Sync
+ 'static,
{
self.mysql_pool_opts_fn = Some(Arc::new(f));
self
}

#[cfg(feature = "sqlx-postgres")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))]
/// Apply a function to modify the underlying [`PgConnectOptions`] before
Expand All @@ -416,6 +470,21 @@ impl ConnectOptions {
self
}

#[cfg(feature = "sqlx-postgres")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqlx-postgres")))]
/// Apply a function to modify the underlying [`sqlx::pool::PoolOptions<sqlx::Postgres>`]
/// before creating the connection pool.
pub fn map_sqlx_postgres_pool_opts<F>(&mut self, f: F) -> &mut Self
where
F: Fn(sqlx::pool::PoolOptions<sqlx::Postgres>) -> sqlx::pool::PoolOptions<sqlx::Postgres>
+ Send
+ Sync
+ 'static,
{
self.pg_pool_opts_fn = Some(Arc::new(f));
self
}

#[cfg(feature = "sqlx-sqlite")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))]
/// Apply a function to modify the underlying [`SqliteConnectOptions`] before
Expand All @@ -427,4 +496,19 @@ impl ConnectOptions {
self.sqlite_opts_fn = Some(Arc::new(f));
self
}

#[cfg(feature = "sqlx-sqlite")]
#[cfg_attr(docsrs, doc(cfg(feature = "sqlx-sqlite")))]
/// Apply a function to modify the underlying [`sqlx::pool::PoolOptions<sqlx::Sqlite>`]
/// before creating the connection pool.
pub fn map_sqlx_sqlite_pool_opts<F>(&mut self, f: F) -> &mut Self
where
F: Fn(sqlx::pool::PoolOptions<sqlx::Sqlite>) -> sqlx::pool::PoolOptions<sqlx::Sqlite>
+ Send
+ Sync
+ 'static,
{
self.sqlite_pool_opts_fn = Some(Arc::new(f));
self
}
}
15 changes: 9 additions & 6 deletions src/driver/sqlx_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,17 @@ impl SqlxMySqlConnector {
if let Some(f) = &options.mysql_opts_fn {
sqlx_opts = f(sqlx_opts);
}

let after_connect = options.after_connect.clone();

let pool = if options.connect_lazy {
options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
let connect_lazy = options.connect_lazy;
let mysql_pool_opts_fn = options.mysql_pool_opts_fn.clone();
let mut pool_options = options.sqlx_pool_options();
if let Some(f) = &mysql_pool_opts_fn {
pool_options = f(pool_options);
}
let pool = if connect_lazy {
pool_options.connect_lazy_with(sqlx_opts)
} else {
options
.sqlx_pool_options()
pool_options
.connect_with(sqlx_opts)
.await
.map_err(sqlx_error_to_conn_err)?
Expand Down
5 changes: 4 additions & 1 deletion src/driver/sqlx_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ impl SqlxPostgresConnector {

let lazy = options.connect_lazy;
let after_connect = options.after_connect.clone();
let pg_pool_opts_fn = options.pg_pool_opts_fn.clone();
let mut pool_options = options.sqlx_pool_options();

if let Some(sql) = set_search_path_sql {
Expand All @@ -115,7 +116,9 @@ impl SqlxPostgresConnector {
})
});
}

if let Some(f) = &pg_pool_opts_fn {
pool_options = f(pool_options);
}
let pool = if lazy {
pool_options.connect_lazy_with(sqlx_opts)
} else {
Expand Down
14 changes: 10 additions & 4 deletions src/driver/sqlx_sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,18 @@ impl SqlxSqliteConnector {
}

let after_conn = options.after_connect.clone();
let connect_lazy = options.connect_lazy;
let sqlite_pool_opts_fn = options.sqlite_pool_opts_fn.clone();
let mut pool_options = options.sqlx_pool_options();

let pool = if options.connect_lazy {
options.sqlx_pool_options().connect_lazy_with(sqlx_opts)
if let Some(f) = &sqlite_pool_opts_fn {
pool_options = f(pool_options);
}

let pool = if connect_lazy {
pool_options.connect_lazy_with(sqlx_opts)
} else {
options
.sqlx_pool_options()
pool_options
.connect_with(sqlx_opts)
.await
.map_err(sqlx_error_to_conn_err)?
Expand Down
Loading