@@ -168,6 +168,7 @@ def __init__(
168168 port : int | None = None ,
169169 use_tls : bool | dict | None = None ,
170170 * ,
171+ database_name : str | None = None ,
171172 backend : str | None = None ,
172173 config_override : "Config | None" = None ,
173174 ) -> None :
@@ -180,7 +181,9 @@ def __init__(
180181 port = int (port )
181182 elif port is None :
182183 port = self ._config ["database.port" ]
183- self .conn_info = dict (host = host , port = port , user = user , passwd = password )
184+ if database_name is None :
185+ database_name = self ._config .get ("database.name" )
186+ self .conn_info = dict (host = host , port = port , user = user , passwd = password , database_name = database_name )
184187 if use_tls is not False :
185188 # use_tls can be: None (auto-detect), True (enable), False (disable), or dict (custom config)
186189 if isinstance (use_tls , dict ):
@@ -201,12 +204,27 @@ def __init__(
201204 backend = self ._config ["database.backend" ]
202205 self .adapter = get_adapter (backend )
203206
207+ if database_name and self .adapter .backend == "mysql" :
208+ warnings .warn (
209+ "database.name is set but the MySQL backend does not support database selection. "
210+ "This setting only applies to PostgreSQL connections." ,
211+ UserWarning ,
212+ stacklevel = 2 ,
213+ )
214+
204215 self .connect ()
205216 if self .is_connected :
206- logger .info ("DataJoint {version} connected to {user}@{host}:{port}" .format (version = __version__ , ** self .conn_info ))
217+ db = self .conn_info .get ("database_name" )
218+ db_str = f"/{ db } " if db else ""
219+ logger .info (
220+ f"DataJoint { __version__ } connected to "
221+ f"{ self .conn_info ['user' ]} @{ self .conn_info ['host' ]} :{ self .conn_info ['port' ]} { db_str } "
222+ )
207223 self .connection_id = self .adapter .get_connection_id (self ._conn )
208224 else :
209- raise errors .LostConnectionError ("Connection failed {user}@{host}:{port}" .format (** self .conn_info ))
225+ raise errors .LostConnectionError (
226+ f"Connection failed { self .conn_info ['user' ]} @{ self .conn_info ['host' ]} :{ self .conn_info ['port' ]} "
227+ )
210228 self ._in_transaction = False
211229 self .schemas = dict ()
212230 self .dependencies = Dependencies (self )
@@ -216,22 +234,33 @@ def __eq__(self, other):
216234
217235 def __repr__ (self ):
218236 connected = "connected" if self .is_connected else "disconnected"
219- return "DataJoint connection ({connected}) {user}@{host}:{port}" .format (connected = connected , ** self .conn_info )
237+ user = self .conn_info ["user" ]
238+ host = self .conn_info ["host" ]
239+ port = self .conn_info ["port" ]
240+ db = self .conn_info .get ("database_name" )
241+ db_str = f"/{ db } " if db else ""
242+ return f"DataJoint connection ({ connected } ) { user } @{ host } :{ port } { db_str } "
243+
244+ def _build_connect_kwargs (self , use_tls = None ):
245+ """Build kwargs dict for adapter.connect()."""
246+ kwargs = dict (
247+ host = self .conn_info ["host" ],
248+ port = self .conn_info ["port" ],
249+ user = self .conn_info ["user" ],
250+ password = self .conn_info ["passwd" ],
251+ charset = self ._config ["connection.charset" ],
252+ use_tls = use_tls if use_tls is not None else self .conn_info .get ("ssl" ),
253+ )
254+ if self .conn_info .get ("database_name" ):
255+ kwargs ["dbname" ] = self .conn_info ["database_name" ]
256+ return kwargs
220257
221258 def connect (self ) -> None :
222259 """Establish or re-establish connection to the database server."""
223260 with warnings .catch_warnings ():
224261 warnings .filterwarnings ("ignore" , ".*deprecated.*" )
225262 try :
226- # Use adapter to create connection
227- self ._conn = self .adapter .connect (
228- host = self .conn_info ["host" ],
229- port = self .conn_info ["port" ],
230- user = self .conn_info ["user" ],
231- password = self .conn_info ["passwd" ],
232- charset = self ._config ["connection.charset" ],
233- use_tls = self .conn_info .get ("ssl" ),
234- )
263+ self ._conn = self .adapter .connect (** self ._build_connect_kwargs ())
235264 except Exception as ssl_error :
236265 # If SSL fails, retry without SSL (if it was auto-detected)
237266 if self .conn_info .get ("ssl_input" ) is None :
@@ -240,14 +269,7 @@ def connect(self) -> None:
240269 "To require SSL, set use_tls=True explicitly." ,
241270 ssl_error ,
242271 )
243- self ._conn = self .adapter .connect (
244- host = self .conn_info ["host" ],
245- port = self .conn_info ["port" ],
246- user = self .conn_info ["user" ],
247- password = self .conn_info ["passwd" ],
248- charset = self ._config ["connection.charset" ],
249- use_tls = False , # Explicitly disable SSL for fallback
250- )
272+ self ._conn = self .adapter .connect (** self ._build_connect_kwargs (use_tls = False ))
251273 else :
252274 raise
253275 self ._is_closed = False # Mark as connected after successful connection
0 commit comments