diff --git a/bin/find-spark-home b/bin/find-spark-home index 462b538b00a04..7c70d588291dc 100755 --- a/bin/find-spark-home +++ b/bin/find-spark-home @@ -23,7 +23,7 @@ FIND_SPARK_HOME_PYTHON_SCRIPT="$(cd "$(dirname "$0")"; pwd)/find_spark_home.py" # Short circuit if the user already has this set. if [ ! -z "${SPARK_HOME}" ]; then - exit 0 + return 0 2>/dev/null || exit 0 elif [ ! -f "$FIND_SPARK_HOME_PYTHON_SCRIPT" ]; then # If we are not in the same directory as find_spark_home.py we are not pip installed so we don't # need to search the different Python directories for a Spark installation. diff --git a/python/pyspark/tests/test_util.py b/python/pyspark/tests/test_util.py index 9fd0135eff4eb..f84f13c923213 100644 --- a/python/pyspark/tests/test_util.py +++ b/python/pyspark/tests/test_util.py @@ -16,6 +16,8 @@ # import gc import os +import subprocess +import sys import time import unittest from unittest.mock import patch @@ -89,6 +91,26 @@ def test_find_spark_home(self): finally: os.environ["SPARK_HOME"] = origin + @unittest.skipIf(sys.platform == "win32", "find-spark-home is a bash script") + def test_find_spark_home_script_sourceable(self): + # SPARK-54434: bin/find-spark-home is documented to be sourced. + # When `SPARK_HOME` is already set, the script short-circuits. + # The short-circuit must use `return 0` rather than `exit 0`, otherwise + # sourcing the script terminates the caller's shell session. + script_path = os.path.join(os.environ["SPARK_HOME"], "bin", "find-spark-home") + # Source the script twice and print a marker afterwards. If the script + # calls `exit 0` while sourced, the outer shell terminates and the + # marker is never emitted. + cmd = ( + f"export SPARK_HOME=/some/value && " + f"source {script_path} && " + f"source {script_path} && " + f"echo SOURCED_OK" + ) + completed = subprocess.run(["bash", "-c", cmd], capture_output=True, text=True, check=False) + self.assertEqual(completed.returncode, 0, msg=completed.stderr) + self.assertIn("SOURCED_OK", completed.stdout) + def test_timeout_decorator(self): @timeout(1) def timeout_func():