@@ -187,15 +187,45 @@ def df_html(df: DataFrame) -> str:
 
 
 @_log_exceptions
-def _initialize_spark_connect_session():
+def _initialize_spark_connect_session_grpc():
     import os
     from dbruntime.spark_connection import get_and_configure_uds_spark
     os.environ["SPARK_REMOTE"] = "unix:///databricks/sparkconnect/grpc.sock"
     spark = get_and_configure_uds_spark()
     globals()["spark"] = spark
 
 
+@_log_exceptions
+def _initialize_spark_connect_session_dbconnect():
+    import IPython
+    from databricks.connect import DatabricksSession
+    user_ns = getattr(IPython.get_ipython(), "user_ns", {})
+    # user_ns is a dict, so look the session up with .get() rather than getattr().
+    existing_session = user_ns.get("spark")
+    if existing_session is not None and _is_spark_connect(existing_session):
+        return
+    try:
+        # Clear the existing local spark session, otherwise DatabricksSession will re-use it.
+        user_ns["spark"] = None
+        globals()["spark"] = None
+        # DatabricksSession will use the existing env vars for the connection.
+        spark_session = DatabricksSession.builder.getOrCreate()
+        user_ns["spark"] = spark_session
+        globals()["spark"] = spark_session
+    except Exception:
+        # Restore the previous session so a failed connection attempt
+        # does not leave the notebook without one.
+        user_ns["spark"] = existing_session
+        globals()["spark"] = existing_session
+        raise
+
+
+def _is_spark_connect(session) -> bool:
+    try:
+        from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
+        return isinstance(session, ConnectSparkSession)
+    except ImportError:
+        return False
+
+
 _register_magics()
 _register_formatters()
 _register_runtime_hooks()
-_initialize_spark_connect_session()
+_initialize_spark_connect_session_dbconnect()
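
A minimal sketch (not part of this change) of how a notebook cell could confirm which kind of session the new initializer left behind. It assumes the `spark` global published by the functions above and reuses the same type check as `_is_spark_connect`:

try:
    from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
    # `spark` is the global set by the initializers in this module.
    is_connect = isinstance(spark, ConnectSparkSession)
except ImportError:
    # pyspark without the Connect client: only a classic session is possible.
    is_connect = False
print("Spark Connect session" if is_connect else "classic SparkSession")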