Skip to content

Commit 43f370b

Browse files
committed
Use DatabricksSession for Spark Connect session initialization
1 parent dbf7277 commit 43f370b

1 file changed

Lines changed: 32 additions & 2 deletions

File tree

experimental/ssh/internal/server/jupyter-init.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,15 +187,45 @@ def df_html(df: DataFrame) -> str:
187187

188188

189189
@_log_exceptions
190-
def _initialize_spark_connect_session():
190+
def _initialize_spark_connect_session_grpc():
191191
import os
192192
from dbruntime.spark_connection import get_and_configure_uds_spark
193193
os.environ["SPARK_REMOTE"] = "unix:///databricks/sparkconnect/grpc.sock"
194194
spark = get_and_configure_uds_spark()
195195
globals()["spark"] = spark
196196

197197

198+
@_log_exceptions
199+
def _initialize_spark_connect_session_dbconnect():
200+
import IPython
201+
from databricks.connect import DatabricksSession
202+
user_ns = getattr(IPython.get_ipython(), "user_ns", {})
203+
existing_session = getattr(user_ns, "spark", None)
204+
if existing_session is not None and _is_spark_connect(existing_session):
205+
return
206+
try:
207+
# Clear the existing local spark session, otherwise DatabricksSession will re-use it.
208+
user_ns["spark"] = None
209+
globals()["spark"] = None
210+
# DatabricksSession will use the existing env vars for the connection.
211+
spark_session = DatabricksSession.builder.getOrCreate()
212+
user_ns["spark"] = spark_session
213+
globals()["spark"] = spark_session
214+
except Exception as e:
215+
user_ns["spark"] = existing_session
216+
globals()["spark"] = existing_session
217+
raise e
218+
219+
220+
def _is_spark_connect(session) -> bool:
221+
try:
222+
from pyspark.sql.connect.session import SparkSession as ConnectSparkSession
223+
return isinstance(session, ConnectSparkSession)
224+
except ImportError:
225+
return False
226+
227+
198228
_register_magics()
199229
_register_formatters()
200230
_register_runtime_hooks()
201-
_initialize_spark_connect_session()
231+
_initialize_spark_connect_session_dbconnect()

0 commit comments

Comments
 (0)