diff --git a/distributed/deploy/subprocess.py b/distributed/deploy/subprocess.py index 0b4323ba107..c2284a138b3 100644 --- a/distributed/deploy/subprocess.py +++ b/distributed/deploy/subprocess.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import asyncio import copy import json @@ -15,13 +16,85 @@ from distributed.compatibility import WINDOWS from distributed.deploy.spec import ProcessInterface, SpecCluster from distributed.deploy.utils import nprocesses_nthreads -from distributed.scheduler import Scheduler from distributed.worker_memory import parse_memory_limit logger = logging.getLogger(__name__) -class SubprocessWorker(ProcessInterface): +class Subprocess(ProcessInterface, abc.ABC): + process: asyncio.subprocess.Process | None + + def __init__(self): + if WINDOWS: + # FIXME: distributed#7434 + raise RuntimeError("Subprocess does not support Windows.") + self.process = None + super().__init__() + + async def start(self) -> None: + await self._start() + await super().start() + + @abc.abstractmethod + async def _start(self) -> None: + """Start the subprocess""" + + async def close(self) -> None: + if self.process and self.process.returncode is None: + for child in psutil.Process(self.process.pid).children(recursive=True): + child.kill() + self.process.kill() + await self.process.communicate() + self.process = None + await super().close() + + +class SubprocessScheduler(Subprocess): + """A local Dask scheduler running in a dedicated subprocess + + Parameters + ---------- + scheduler_kwargs: + Keywords to pass on to the ``Scheduler`` class constructor + """ + + scheduler_kwargs: dict + address: str | None + + def __init__( + self, + scheduler_kwargs: dict | None = None, + ): + self.scheduler_kwargs = scheduler_kwargs or {} + super().__init__() + + async def _start(self): + cmd = [ + "dask", + "spec", + "--spec", + json.dumps( + {"cls": "distributed.Scheduler", "opts": {**self.scheduler_kwargs}} + ), + ] + logger.info(" ".join(cmd)) + self.process = await asyncio.create_subprocess_exec( + *cmd, + stderr=asyncio.subprocess.PIPE, + ) + + while True: + line = (await self.process.stderr.readline()).decode() + if not line.strip(): + raise RuntimeError("Scheduler failed to start") + logger.info(line.strip()) + if "Scheduler at" in line: + self.address = line.split("Scheduler at:")[1].strip() + break + logger.debug(line) + + +class SubprocessWorker(Subprocess): """A local Dask worker running in a dedicated subprocess Parameters @@ -36,11 +109,10 @@ class SubprocessWorker(ProcessInterface): Keywords to pass on to the ``Worker`` class constructor """ + name: str | None scheduler: str worker_class: str worker_kwargs: dict - name: str | None - process: asyncio.subprocess.Process | None def __init__( self, @@ -49,34 +121,22 @@ def __init__( name: str | None = None, worker_kwargs: dict | None = None, ) -> None: - if WINDOWS: - # FIXME: distributed#7434 - raise RuntimeError("SubprocessWorker does not support Windows.") + self.name = name self.scheduler = scheduler self.worker_class = worker_class - self.name = name self.worker_kwargs = copy.copy(worker_kwargs or {}) - self.process = None super().__init__() - async def start(self) -> None: - self.process = await asyncio.create_subprocess_exec( + async def _start(self) -> None: + cmd = [ "dask", "spec", self.scheduler, "--spec", - json.dumps({0: {"cls": self.worker_class, "opts": {**self.worker_kwargs}}}), - ) - await super().start() - - async def close(self) -> None: - if self.process and self.process.returncode is None: - for child in psutil.Process(self.process.pid).children(recursive=True): - child.kill() - self.process.kill() - await self.process.wait() - self.process = None - await super().close() + json.dumps({"cls": self.worker_class, "opts": {**self.worker_kwargs}}), + ] + logger.info(" ".join(cmd)) + self.process = await asyncio.create_subprocess_exec(*cmd) def SubprocessCluster( @@ -91,10 +151,9 @@ def SubprocessCluster( silence_logs: int = logging.WARN, **kwargs: Any, ) -> SpecCluster: - """Create in-process scheduler and workers running in dedicated subprocesses + """Create a scheduler and workers that run in dedicated subprocesses - This creates a "cluster" of a scheduler running in the current process and - workers running in dedicated subprocesses. + This creates a "cluster" of a scheduler and workers running in dedicated subprocesses. .. warning:: @@ -178,7 +237,12 @@ def SubprocessCluster( worker_kwargs, ) - scheduler = {"cls": Scheduler, "options": scheduler_kwargs} + scheduler = { + "cls": SubprocessScheduler, + "options": { + "scheduler_kwargs": scheduler_kwargs, + }, + } worker = { "cls": SubprocessWorker, "options": {"worker_class": worker_class, "worker_kwargs": worker_kwargs}, diff --git a/distributed/deploy/tests/test_subprocess.py b/distributed/deploy/tests/test_subprocess.py index be7a11ce322..3382fef88a8 100644 --- a/distributed/deploy/tests/test_subprocess.py +++ b/distributed/deploy/tests/test_subprocess.py @@ -4,7 +4,11 @@ from distributed import Client from distributed.compatibility import WINDOWS -from distributed.deploy.subprocess import SubprocessCluster, SubprocessWorker +from distributed.deploy.subprocess import ( + SubprocessCluster, + SubprocessScheduler, + SubprocessWorker, +) from distributed.utils_test import gen_test @@ -53,7 +57,6 @@ async def test_scale_up_and_down(): cluster.scale(2) await c.wait_for_workers(2) assert len(cluster.workers) == 2 - assert len(cluster.scheduler.workers) == 2 cluster.scale(1) await cluster @@ -61,6 +64,14 @@ async def test_scale_up_and_down(): assert len(cluster.workers) == 1 +@pytest.mark.skipif(WINDOWS, reason="distributed#7434") +@gen_test() +async def test_raise_if_scheduler_fails_to_start(): + with pytest.raises(RuntimeError, match="Scheduler failed to start"): + async with SubprocessCluster(scheduler_port=-1, asynchronous=True): + pass + + @pytest.mark.skipif( not WINDOWS, reason="Windows-specific error testing (distributed#7434)" ) @@ -68,5 +79,8 @@ def test_raise_on_windows(): with pytest.raises(RuntimeError, match="not support Windows"): SubprocessCluster() + with pytest.raises(RuntimeError, match="not support Windows"): + SubprocessScheduler() + with pytest.raises(RuntimeError, match="not support Windows"): SubprocessWorker(scheduler="tcp://127.0.0.1:8786")