Skip to content

Commit 7e825eb

Browse files
Merge pull request #1426 from kushalbakshi/master
feat: add database.dbname setting for PostgreSQL connections
2 parents 367f5d1 + eabcfb0 commit 7e825eb

File tree

4 files changed

+122
-23
lines changed

4 files changed

+122
-23
lines changed

src/datajoint/connection.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def __init__(
168168
port: int | None = None,
169169
use_tls: bool | dict | None = None,
170170
*,
171+
database_name: str | None = None,
171172
backend: str | None = None,
172173
config_override: "Config | None" = None,
173174
) -> None:
@@ -180,7 +181,9 @@ def __init__(
180181
port = int(port)
181182
elif port is None:
182183
port = self._config["database.port"]
183-
self.conn_info = dict(host=host, port=port, user=user, passwd=password)
184+
if database_name is None:
185+
database_name = self._config.get("database.name")
186+
self.conn_info = dict(host=host, port=port, user=user, passwd=password, database_name=database_name)
184187
if use_tls is not False:
185188
# use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config)
186189
if isinstance(use_tls, dict):
@@ -201,12 +204,27 @@ def __init__(
201204
backend = self._config["database.backend"]
202205
self.adapter = get_adapter(backend)
203206

207+
if database_name and self.adapter.backend == "mysql":
208+
warnings.warn(
209+
"database.name is set but the MySQL backend does not support database selection. "
210+
"This setting only applies to PostgreSQL connections.",
211+
UserWarning,
212+
stacklevel=2,
213+
)
214+
204215
self.connect()
205216
if self.is_connected:
206-
logger.info("DataJoint {version} connected to {user}@{host}:{port}".format(version=__version__, **self.conn_info))
217+
db = self.conn_info.get("database_name")
218+
db_str = f"/{db}" if db else ""
219+
logger.info(
220+
f"DataJoint {__version__} connected to "
221+
f"{self.conn_info['user']}@{self.conn_info['host']}:{self.conn_info['port']}{db_str}"
222+
)
207223
self.connection_id = self.adapter.get_connection_id(self._conn)
208224
else:
209-
raise errors.LostConnectionError("Connection failed {user}@{host}:{port}".format(**self.conn_info))
225+
raise errors.LostConnectionError(
226+
f"Connection failed {self.conn_info['user']}@{self.conn_info['host']}:{self.conn_info['port']}"
227+
)
210228
self._in_transaction = False
211229
self.schemas = dict()
212230
self.dependencies = Dependencies(self)
@@ -216,22 +234,33 @@ def __eq__(self, other):
216234

217235
def __repr__(self):
218236
connected = "connected" if self.is_connected else "disconnected"
219-
return "DataJoint connection ({connected}) {user}@{host}:{port}".format(connected=connected, **self.conn_info)
237+
user = self.conn_info["user"]
238+
host = self.conn_info["host"]
239+
port = self.conn_info["port"]
240+
db = self.conn_info.get("database_name")
241+
db_str = f"/{db}" if db else ""
242+
return f"DataJoint connection ({connected}) {user}@{host}:{port}{db_str}"
243+
244+
def _build_connect_kwargs(self, use_tls=None):
245+
"""Build kwargs dict for adapter.connect()."""
246+
kwargs = dict(
247+
host=self.conn_info["host"],
248+
port=self.conn_info["port"],
249+
user=self.conn_info["user"],
250+
password=self.conn_info["passwd"],
251+
charset=self._config["connection.charset"],
252+
use_tls=use_tls if use_tls is not None else self.conn_info.get("ssl"),
253+
)
254+
if self.conn_info.get("database_name"):
255+
kwargs["dbname"] = self.conn_info["database_name"]
256+
return kwargs
220257

221258
def connect(self) -> None:
222259
"""Establish or re-establish connection to the database server."""
223260
with warnings.catch_warnings():
224261
warnings.filterwarnings("ignore", ".*deprecated.*")
225262
try:
226-
# Use adapter to create connection
227-
self._conn = self.adapter.connect(
228-
host=self.conn_info["host"],
229-
port=self.conn_info["port"],
230-
user=self.conn_info["user"],
231-
password=self.conn_info["passwd"],
232-
charset=self._config["connection.charset"],
233-
use_tls=self.conn_info.get("ssl"),
234-
)
263+
self._conn = self.adapter.connect(**self._build_connect_kwargs())
235264
except Exception as ssl_error:
236265
# If SSL fails, retry without SSL (if it was auto-detected)
237266
if self.conn_info.get("ssl_input") is None:
@@ -240,14 +269,7 @@ def connect(self) -> None:
240269
"To require SSL, set use_tls=True explicitly.",
241270
ssl_error,
242271
)
243-
self._conn = self.adapter.connect(
244-
host=self.conn_info["host"],
245-
port=self.conn_info["port"],
246-
user=self.conn_info["user"],
247-
password=self.conn_info["passwd"],
248-
charset=self._config["connection.charset"],
249-
use_tls=False, # Explicitly disable SSL for fallback
250-
)
272+
self._conn = self.adapter.connect(**self._build_connect_kwargs(use_tls=False))
251273
else:
252274
raise
253275
self._is_closed = False # Mark as connected after successful connection

src/datajoint/schemas.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,13 @@ def activate(
172172
self.connection = connection
173173
if self.connection is None:
174174
self.connection = _get_singleton_connection()
175+
if self.connection._config.get("database.database_prefix"):
176+
warnings.warn(
177+
"database_prefix is deprecated and will be removed in DataJoint 2.3. "
178+
"Use database.name to select a PostgreSQL database instead.",
179+
DeprecationWarning,
180+
stacklevel=2,
181+
)
175182
self.database = schema_name
176183
if create_schema is not None:
177184
self.create_schema = create_schema

src/datajoint/settings.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"database.password": "DJ_PASS",
6666
"database.backend": "DJ_BACKEND",
6767
"database.port": "DJ_PORT",
68+
"database.name": "DJ_DATABASE_NAME",
6869
"database.database_prefix": "DJ_DATABASE_PREFIX",
6970
"database.create_tables": "DJ_CREATE_TABLES",
7071
"loglevel": "DJ_LOG_LEVEL",
@@ -196,13 +197,17 @@ class DatabaseSettings(BaseSettings):
196197
description="Database backend: 'mysql' or 'postgresql'",
197198
)
198199
port: int | None = Field(default=None, validation_alias="DJ_PORT")
200+
name: str | None = Field(
201+
default=None,
202+
validation_alias="DJ_DATABASE_NAME",
203+
description="Database name for PostgreSQL connections. Defaults to 'postgres' if not set.",
204+
)
199205
reconnect: bool = True
200206
use_tls: bool | None = Field(default=None, validation_alias="DJ_USE_TLS")
201207
database_prefix: str = Field(
202208
default="",
203209
validation_alias="DJ_DATABASE_PREFIX",
204-
description="Prefix for database/schema names. "
205-
"Not automatically applied; use dj.config.database.database_prefix when creating schemas.",
210+
description="Deprecated. Use database.name instead.",
206211
)
207212
create_tables: bool = Field(
208213
default=True,

tests/unit/test_settings.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,71 @@ def test_similar_prefix_names_allowed(self):
750750
dj.config.stores.update(original_stores)
751751

752752

753+
class TestDatabaseNameConfiguration:
754+
"""Test database.name configuration."""
755+
756+
def test_database_name_default_is_none(self):
757+
"""Database name defaults to None when not configured."""
758+
from datajoint.settings import DatabaseSettings
759+
760+
s = DatabaseSettings()
761+
assert s.name is None
762+
763+
def test_database_name_env_var(self, monkeypatch):
764+
"""DJ_DATABASE_NAME environment variable sets database name."""
765+
from datajoint.settings import DatabaseSettings
766+
767+
monkeypatch.setenv("DJ_DATABASE_NAME", "my_database")
768+
s = DatabaseSettings()
769+
assert s.name == "my_database"
770+
771+
def test_database_name_from_config_file(self, tmp_path, monkeypatch):
772+
"""Load database name from config file."""
773+
import json
774+
775+
from datajoint.settings import Config
776+
777+
config_file = tmp_path / "test_config.json"
778+
config_file.write_text(json.dumps({"database": {"name": "custom_db", "host": "localhost"}}))
779+
780+
monkeypatch.delenv("DJ_DATABASE_NAME", raising=False)
781+
monkeypatch.delenv("DJ_HOST", raising=False)
782+
783+
cfg = Config()
784+
cfg.load(config_file)
785+
assert cfg.database.name == "custom_db"
786+
787+
def test_database_name_dict_access(self):
788+
"""Dict-style access reads and writes database name."""
789+
original = dj.config.database.name
790+
try:
791+
dj.config.database.name = "test_db"
792+
assert dj.config["database.name"] == "test_db"
793+
finally:
794+
dj.config.database.name = original
795+
796+
def test_database_name_override_context_manager(self):
797+
"""Override context manager temporarily sets database name."""
798+
original = dj.config.database.name
799+
with dj.config.override(database__name="override_db"):
800+
assert dj.config.database.name == "override_db"
801+
assert dj.config.database.name == original
802+
803+
def test_database_prefix_empty_no_warning(self):
804+
"""Empty database_prefix does not emit DeprecationWarning at config load."""
805+
import warnings
806+
807+
from datajoint.settings import DatabaseSettings
808+
809+
with warnings.catch_warnings(record=True) as w:
810+
warnings.simplefilter("always")
811+
DatabaseSettings()
812+
deprecation_warnings = [
813+
x for x in w if issubclass(x.category, DeprecationWarning) and "database_prefix" in str(x.message)
814+
]
815+
assert len(deprecation_warnings) == 0
816+
817+
753818
class TestBackendConfiguration:
754819
"""Test database backend configuration and port auto-detection."""
755820

0 commit comments

Comments
 (0)