Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions providers/apache/spark/docs/operators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ independently on the cluster. If the Airflow worker dies while the Spark job is
Airflow loses track of it and the behaviour to submit a brand new job would be wasting
the compute already done or even cause conflicts if the Spark job itself is not designed to be idempotent.

Now, the ``SparkSubmitOperator`` solves this by persisting the driver ID to ``task_state`` immediately after
Now, the ``SparkSubmitOperator`` solves this by persisting the driver ID to ``task_store`` immediately after
submission. On retry, it reads the ID back and reconnects to the already-running driver instead of
resubmitting.

Expand All @@ -212,7 +212,7 @@ The reconnection polling calls the Spark standalone REST API
See :doc:`connections/spark-submit` for how to configure these fields.

.. note::
Crash recovery in cluster mode requires Airflow 3.3+ (``task_state`` support). On earlier
Crash recovery in cluster mode requires Airflow 3.3+ (``task_store`` support). On earlier
versions the operator falls back to the previous behavior of always submitting fresh.

Tracking driver status via Kubernetes API
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
# ResumableJobMixin does not exist in Airflow 2, so we need to add a stub to make it
# behave as before
class ResumableJobMixin: # type: ignore[no-redef]
"""Airflow 2 stub — no task_state, always submits fresh."""
"""Airflow 2 stub — no task_store, always submits fresh."""

external_id_key: str = "remote_job_id"

Expand Down Expand Up @@ -264,7 +264,7 @@ def execute(self, context: Context) -> None:
if hook._should_track_driver_status:
if self.reconnect_on_retry:
return self.execute_resumable(context)
# reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence.
# reconnect_on_retry=False: still submit-and-poll, just skip task_store persistence.
driver_id = self.submit_job(context)
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
Expand All @@ -284,7 +284,7 @@ def execute(self, context: Context) -> None:
hook._validate_yarn_track_via_rm_api_config()
if self.reconnect_on_retry:
return self.execute_resumable(context)
# reconnect_on_retry=False: still submit-and-poll, just skip task_state persistence.
# reconnect_on_retry=False: still submit-and-poll, just skip task_store persistence.
driver_id = self.submit_job(context)
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ def test_inject_openlineage_simple_config_wrong_transport_to_spark(
}


class FakeTaskState:
class FakeTaskStore:
"""In-memory task state for tests."""

def __init__(self, stored: dict[str, str] | None = None):
Expand Down Expand Up @@ -528,7 +528,7 @@ def test_cluster_mode_first_run_persists_id_before_polling(self):
operator._hook = self._make_hook(should_track=True)
operator._hook.submit.return_value = "driver-001"

task_store = FakeTaskState()
task_store = FakeTaskStore()
persisted_before_poll = []

def track_poll(external_id, context):
Expand All @@ -555,7 +555,7 @@ def test_retry_behaviour_based_on_prior_driver_status(self, prior_status, expect
operator = self._make_operator()
operator._hook = self._make_hook(should_track=True)
operator._hook.submit.return_value = "driver-new"
task_store = FakeTaskState({"spark_job_id": "driver-001"})
task_store = FakeTaskStore({"spark_job_id": "driver-001"})

operator.get_job_status = lambda external_id, context: prior_status
polled = []
Expand Down Expand Up @@ -590,7 +590,7 @@ def test_reconnect_on_retry_false_submits_fresh_and_polls(self):
operator = self._make_operator(reconnect_on_retry=False)
operator._hook = self._make_hook(should_track=True)
operator._hook.submit.return_value = "driver-new"
task_store = FakeTaskState({"spark_job_id": "driver-old"})
task_store = FakeTaskStore({"spark_job_id": "driver-old"})
polled = []
operator.poll_until_complete = lambda external_id, context: polled.append(external_id)

Expand Down Expand Up @@ -733,7 +733,7 @@ def test_yarn_first_run_persists_app_id_before_polling(self):
operator._hook._yarn_application_id = "application_1234_0001"
operator._hook.submit.return_value = None

task_store = FakeTaskState()
task_store = FakeTaskStore()
persisted_before_poll = []

def track_poll(external_id, context):
Expand All @@ -747,7 +747,7 @@ def track_poll(external_id, context):
def test_yarn_retry_reconnects_to_running_app(self):
operator = self._make_operator()
operator._hook = self._make_hook(is_yarn_cluster=True)
task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
task_store = FakeTaskStore({"spark_job_id": "application_1234_0001"})

operator.get_job_status = lambda external_id, context: "RUNNING"
polled = []
Expand All @@ -761,7 +761,7 @@ def test_yarn_retry_reconnects_to_running_app(self):
def test_yarn_retry_skips_already_succeeded_app(self):
operator = self._make_operator()
operator._hook = self._make_hook(is_yarn_cluster=True)
task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
task_store = FakeTaskStore({"spark_job_id": "application_1234_0001"})

operator.get_job_status = lambda external_id, context: "SUCCEEDED"

Expand All @@ -775,7 +775,7 @@ def test_yarn_retry_resubmits_after_failed_app(self):
operator._hook._conf = {}
operator._hook._yarn_application_id = "application_1234_0002"
operator._hook.submit.return_value = None
task_store = FakeTaskState({"spark_job_id": "application_1234_0001"})
task_store = FakeTaskStore({"spark_job_id": "application_1234_0001"})

operator.get_job_status = lambda external_id, context: "FAILED"
polled = []
Expand Down
Loading