async def get_ranges(
    self,
    key: str,
    byte_ranges: Sequence[ByteRequest | None],
    *,
    prototype: BufferPrototype,
    max_concurrency: int = 10,
    max_gap_bytes: int = 1 << 20,  # 1 MiB
    max_coalesced_bytes: int = 16 << 20,  # 16 MiB
) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]:
    """Read many byte ranges from `key`.

    Yields one batch per underlying I/O operation, each a sequence of
    `(input_index, Buffer | None)` tuples. Batches across yields arrive
    in completion order, not input order. This default implementation
    runs the coalescer over `self.get`, so every subclass inherits a
    working version for free; stores with a more efficient backend
    (e.g. ranged HTTP, S3 byte-range fetches) should override.

    Parameters
    ----------
    key
        Storage key to read from.
    byte_ranges
        Input ranges. `None` means "the whole value".
    prototype
        Buffer prototype, forwarded to `self.get`.
    max_concurrency
        Maximum number of merged fetches in flight at once.
    max_gap_bytes
        Two `RangeByteRequest`s separated by at most this many bytes may
        be merged into one fetch.
    max_coalesced_bytes
        Upper bound on the size of a single merged fetch.

    Raises
    ------
    BaseExceptionGroup
        Failures from underlying fetches are reported as a
        `BaseExceptionGroup` (PEP 654) and should be handled with
        `except*`. Inner exceptions include `FileNotFoundError` if any
        fetch returns `None` (i.e. `key` is absent), plus whatever
        `self.get` raises for the corresponding range. Pending fetches
        are cancelled as soon as one task fails, so the group typically
        contains a single non-`CancelledError` exception even under high
        concurrency.
    """
    # Local import: zarr.core._coalesce imports symbols from this module,
    # so a top-level import would be circular.
    from zarr.core._coalesce import coalesced_get

    fetch = partial(self.get, key, prototype)
    batches = coalesced_get(
        fetch,
        byte_ranges,
        max_concurrency=max_concurrency,
        max_gap_bytes=max_gap_bytes,
        max_coalesced_bytes=max_coalesced_bytes,
    )
    async for batch in batches:
        yield batch
diff --git a/src/zarr/core/_coalesce.py b/src/zarr/core/_coalesce.py new file mode 100644 index 0000000000..08100e822f --- /dev/null +++ b/src/zarr/core/_coalesce.py @@ -0,0 +1,227 @@ +# src/zarr/core/_coalesce.py +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, NamedTuple + +from zarr.abc.store import RangeByteRequest + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence + + from zarr.abc.store import ByteRequest + from zarr.core.buffer import Buffer + + +class _WorkerCtx(NamedTuple): + """Shared state passed to the per-task worker coroutines. + + Bundling these lets the workers declare their dependencies as one + parameter instead of capturing them implicitly via closure. + """ + + fetch: Callable[[ByteRequest | None], Awaitable[Buffer | None]] + semaphore: asyncio.Semaphore + + +async def _fetch_single( + ctx: _WorkerCtx, idx: int, req: ByteRequest | None +) -> Sequence[tuple[int, Buffer | None]]: + """Fetch one byte range. Raises FileNotFoundError if the key is absent.""" + async with ctx.semaphore: + buf = await ctx.fetch(req) + if buf is None: + raise FileNotFoundError + return ((idx, buf),) + + +async def _fetch_group( + ctx: _WorkerCtx, members: list[tuple[int, RangeByteRequest]] +) -> Sequence[tuple[int, Buffer | None]]: + """Fetch one merged byte range and slice it back into per-input buffers. + + `members` must already be sorted by `start`; callers in this module + build it from the sorted mergeable list. Raises `FileNotFoundError` + if the key is absent. 
+ """ + if len(members) == 1: + solo_idx, solo_req = members[0] + return await _fetch_single(ctx, solo_idx, solo_req) + + start = members[0][1].start + end = max(r.end for _, r in members) + async with ctx.semaphore: + big = await ctx.fetch(RangeByteRequest(start, end)) + if big is None: + raise FileNotFoundError + sliced = [(idx, big[r.start - start : r.end - start]) for idx, r in members] + return tuple(sliced) + + +def coalesce_ranges( + byte_ranges: Sequence[ByteRequest | None], + *, + max_gap_bytes: int, + max_coalesced_bytes: int, +) -> tuple[ + list[list[tuple[int, RangeByteRequest]]], + list[tuple[int, ByteRequest | None]], +]: + """Plan a set of byte-range fetches: which inputs merge, which stand alone. + + Pure (no I/O). The result is the I/O plan a caller would execute: each + group corresponds to one fetch of a coalesced byte range, and each + uncoalescable item corresponds to one fetch of the original request. + + All tuning knobs are required keyword arguments. `Store.get_ranges` is + the public entry point and owns the canonical default values; this + function takes them explicitly to avoid duplicating policy. + + Parameters + ---------- + byte_ranges + Input ranges. `None` means "the whole value". + max_gap_bytes + Two `RangeByteRequest`s separated by at most this many bytes may be + merged into one fetch. + max_coalesced_bytes + Upper bound on the size of a single merged fetch. + + Returns + ------- + groups + List of merged groups. Each group is a list of + `(input_index, RangeByteRequest)` pairs sorted by `start`. A + single-element group represents a `RangeByteRequest` that did not + merge with any neighbor. + uncoalescable + List of `(input_index, request)` pairs for inputs that are not + `RangeByteRequest` (`OffsetByteRequest`, `SuffixByteRequest`, + `None`). Indices are preserved from the input order. + + Notes + ----- + Only `RangeByteRequest` inputs participate in coalescing. 
Two ranges + merge when both: their gap (next `start` minus current group's running + `end`) is `<= max_gap_bytes`, and the resulting merged span is + `<= max_coalesced_bytes`. + """ + indexed: list[tuple[int, ByteRequest | None]] = list(enumerate(byte_ranges)) + mergeable: list[tuple[int, RangeByteRequest]] = [ + (i, r) for i, r in indexed if isinstance(r, RangeByteRequest) + ] + uncoalescable: list[tuple[int, ByteRequest | None]] = [ + (i, r) for i, r in indexed if not isinstance(r, RangeByteRequest) + ] + + # Sort mergeables by start offset, then merge. Track running start/end of the + # current group so each merge step is O(1) instead of O(group size). + mergeable.sort(key=lambda pair: pair[1].start) + groups: list[list[tuple[int, RangeByteRequest]]] = [] + group_start = 0 + group_end = 0 + for pair in mergeable: + _i, r = pair + if groups and r.start - group_end <= max_gap_bytes: + prospective_end = max(group_end, r.end) + if prospective_end - group_start <= max_coalesced_bytes: + groups[-1].append(pair) + group_end = prospective_end + continue + groups.append([pair]) + group_start = r.start + group_end = r.end + + return groups, uncoalescable + + +async def coalesced_get( + fetch: Callable[[ByteRequest | None], Awaitable[Buffer | None]], + byte_ranges: Sequence[ByteRequest | None], + *, + max_concurrency: int, + max_gap_bytes: int, + max_coalesced_bytes: int, +) -> AsyncGenerator[Sequence[tuple[int, Buffer | None]]]: + """Read many byte ranges through `fetch` with coalescing and concurrency. + + Nearby ranges are merged into a single underlying I/O, and merged fetches + are run concurrently. Each yield corresponds to exactly one underlying I/O + operation: a sequence of `(input_index, result)` tuples for all input + ranges served by that I/O. Tuples within a yielded sequence are ordered by + start offset. Yields across groups are in completion order, not input + order. + + All tuning knobs are required keyword arguments. 
`Store.get_ranges` is + the public entry point and owns the canonical default values; this + function takes them explicitly to avoid duplicating policy. + + Parameters + ---------- + fetch + Callable that reads one byte range and returns a `Buffer` (or `None` + if the underlying key does not exist). Typically constructed via + `functools.partial(store.get, key, prototype)`. + byte_ranges + Input ranges. `None` means "the whole value". + max_concurrency + Maximum number of merged fetches in flight at once. + max_gap_bytes + Forwarded to `coalesce_ranges`. + max_coalesced_bytes + Forwarded to `coalesce_ranges`. + + Yields + ------ + Sequence[tuple[int, Buffer | None]] + Per-I/O batch of `(input_index, result)` tuples. + + Notes + ----- + - Only `RangeByteRequest` inputs are coalesced. `OffsetByteRequest`, + `SuffixByteRequest`, and `None` are each treated as uncoalescable + (one fetch, one single-tuple yield per input). + - Failures from underlying fetches surface as a `BaseExceptionGroup` + (PEP 654). Inner exceptions include `FileNotFoundError` if a fetch + returns `None`, plus any exception `fetch` raises. Pending fetches are + cancelled as soon as one task fails, so the group typically contains a + single non-`CancelledError` exception even under high concurrency. + - Groups completed before the failure remain observable on the yields + preceding the raise. + - `GeneratorExit` raised by `aclose()` is filtered out so the iterator + closes cleanly; callers don't see a group containing only it. + """ + if not byte_ranges: + return + + groups, singles = coalesce_ranges( + byte_ranges, + max_gap_bytes=max_gap_bytes, + max_coalesced_bytes=max_coalesced_bytes, + ) + + ctx = _WorkerCtx(fetch=fetch, semaphore=asyncio.Semaphore(max_concurrency)) + + # Launch all work as tasks. The semaphore bounds actual I/O concurrency. 
+ # TaskGroup wraps task exceptions in BaseExceptionGroup; we propagate the + # group unchanged as part of the public contract (callers handle batch + # failures via `except*` / PEP 654). GeneratorExit (raised when the + # consumer calls aclose()) is filtered out so close completes cleanly. + try: + async with asyncio.TaskGroup() as tg: + tasks = [ + *(tg.create_task(_fetch_group(ctx, group)) for group in groups), + *(tg.create_task(_fetch_single(ctx, i, single)) for i, single in singles), + ] + + for fut in asyncio.as_completed(tasks): + yield await fut + except BaseExceptionGroup as eg: + # Strip GeneratorExits (consumer aclose()) and propagate whatever + # remains. `split` is used instead of `subgroup` because the latter + # short-circuits on the group object itself, returning the unchanged + # group when a predicate lambda happens to be true for the wrapper. + _, rest = eg.split(GeneratorExit) + if rest is None: + return # only GeneratorExits — clean close + raise rest from None diff --git a/src/zarr/storage/_wrapper.py b/src/zarr/storage/_wrapper.py index d8ecfa6d45..37aeb8166f 100644 --- a/src/zarr/storage/_wrapper.py +++ b/src/zarr/storage/_wrapper.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, cast if TYPE_CHECKING: - from collections.abc import AsyncGenerator, AsyncIterator, Iterable + from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Sequence from types import TracebackType from typing import Any, Self @@ -103,6 +103,32 @@ async def get_partial_values( ) -> list[Buffer | None]: return await self._store.get_partial_values(prototype, key_ranges) + async def get_ranges( + self, + key: str, + byte_ranges: Sequence[ByteRequest | None], + *, + prototype: BufferPrototype, + max_concurrency: int | None = None, + max_gap_bytes: int | None = None, + max_coalesced_bytes: int | None = None, + ) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]: + """Forward `get_ranges` to the wrapped store. 
async def get_ranges(
    self,
    key: str,
    byte_ranges: Sequence[ByteRequest | None],
    *,
    prototype: BufferPrototype,
    max_concurrency: int | None = None,
    max_gap_bytes: int | None = None,
    max_coalesced_bytes: int | None = None,
) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]:
    """Forward `get_ranges` to the wrapped store.

    Default values for the coalescing kwargs are deliberately not declared
    here; the wrapped store decides them. `None` means "don't override the
    wrapped store's default".
    """
    overrides: dict[str, int] = {}
    if max_concurrency is not None:
        overrides["max_concurrency"] = max_concurrency
    if max_gap_bytes is not None:
        overrides["max_gap_bytes"] = max_gap_bytes
    if max_coalesced_bytes is not None:
        overrides["max_coalesced_bytes"] = max_coalesced_bytes
    async for batch in self._store.get_ranges(key, byte_ranges, prototype=prototype, **overrides):
        yield batch


async def exists(self, key: str) -> bool:
    """Forward `exists` to the wrapped store."""
    return await self._store.exists(key)


# --- tests/test_coalesce.py: shared fixtures and helpers --------------------


def _buf(data: bytes) -> Buffer:
    """Wrap raw bytes in the default buffer prototype."""
    return default_buffer_prototype().buffer.from_bytes(data)


@dataclass
class FakeFetch:
    """Records every call and serves canned bytes from an in-memory blob."""

    blob: bytes
    key_exists: bool = True
    raise_on: Callable[[ByteRequest | None], bool] | None = None
    calls: list[ByteRequest | None] = field(default_factory=list)

    async def __call__(self, byte_range: ByteRequest | None) -> Buffer | None:
        self.calls.append(byte_range)
        if not self.key_exists:
            return None
        if self.raise_on is not None and self.raise_on(byte_range):
            raise OSError("injected")
        if byte_range is None:
            return _buf(self.blob)
        if isinstance(byte_range, RangeByteRequest):
            return _buf(self.blob[byte_range.start : byte_range.end])
        if isinstance(byte_range, OffsetByteRequest):
            return _buf(self.blob[byte_range.offset :])
        if isinstance(byte_range, SuffixByteRequest):
            return _buf(self.blob[-byte_range.suffix :])
        raise AssertionError(f"unknown byte_range {byte_range!r}")


async def _collect(
    agen: AsyncIterator[Sequence[tuple[int, Buffer | None]]],
) -> list[list[tuple[int, Buffer | None]]]:
    """Drain an async generator of groups into a list of lists of tuples."""
    return [list(group) async for group in agen]


def _contents(groups: list[list[tuple[int, Buffer | None]]]) -> dict[int, bytes]:
    """Flatten to {index: bytes}."""
    flat: dict[int, bytes] = {}
    for group in groups:
        for idx, buf in group:
            assert buf is not None
            flat[idx] = buf.to_bytes()
    return flat
# ---------------------------------------------------------------------------
# Shared coalescing-knob bundles. Each is a complete mapping of all three
# kwargs to splat into `coalesced_get`; `coalesce_ranges` ignores
# `max_concurrency`. The leaf functions in `_coalesce.py` require all knobs
# explicitly — `Store.get_ranges` is the public entry point and owns the
# canonical defaults. Tests pick their own values appropriate to the scenario.
# ---------------------------------------------------------------------------

# Permissive baseline for tests that don't care about specific thresholds.
# Mirrors `Store.get_ranges`'s public defaults, but owned here independently
# of any production constants.
DEFAULT: Mapping[str, int] = {
    "max_concurrency": 10,
    "max_gap_bytes": 1 << 20,
    "max_coalesced_bytes": 16 << 20,
}

# Merge ranges within 50 bytes of each other.
MERGE_GAP_50: Mapping[str, int] = {
    "max_concurrency": 10,
    "max_gap_bytes": 50,
    "max_coalesced_bytes": 1 << 20,
}

# No merging: any positive gap is > -1, so no pair ever coalesces.
NO_MERGE: Mapping[str, int] = {
    "max_concurrency": 10,
    "max_gap_bytes": -1,
    "max_coalesced_bytes": 1 << 20,
}

# Gap permissive but merged size capped at 50 bytes.
CAP_50: Mapping[str, int] = {
    "max_concurrency": 10,
    "max_gap_bytes": 1000,
    "max_coalesced_bytes": 50,
}


def _grouping(opts: Mapping[str, int]) -> dict[str, int]:
    """Return only the grouping knobs from a full options bundle.

    `coalesce_ranges` rejects `max_concurrency`; this lets test bundles be
    full kwargs maps (for `coalesced_get`) and still be passed to the pure
    planner via splat.
    """
    return {k: v for k, v in opts.items() if k != "max_concurrency"}


# A deterministic blob used for content-sensitive cases: byte i == (i % 256).
_INDEXED_BLOB = bytes(i % 256 for i in range(10_000))


# ---------------------------------------------------------------------------
# Parametrized structural/content tests (cases without async timing or errors).
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class StructuralCase:
    """One row of the parametrized structure-and-contents table."""

    # pytest id for the case.
    id: str
    # Input to coalesced_get.
    ranges: list[ByteRequest | None]
    # Coalescing knobs to splat into `coalesced_get`.
    options: Mapping[str, int]
    # Sorted list of group tuple-counts (order-independent).
    expected_group_sizes: list[int]
    # {input_index: bytes} to verify bytes, or None to skip the content check.
    expected_contents: dict[int, bytes] | None = None
    # Exact number of calls to the fetch callable, or None to skip the check.
    expected_n_fetches: int | None = None


_STRUCTURAL_CASES: list[StructuralCase] = [
    StructuralCase(
        id="empty-input",
        ranges=[],
        options=DEFAULT,
        expected_group_sizes=[],
        expected_n_fetches=0,
    ),
    StructuralCase(
        id="single-range",
        ranges=[RangeByteRequest(2, 5)],
        options=DEFAULT,
        expected_group_sizes=[1],
        expected_contents={0: _INDEXED_BLOB[2:5]},
        expected_n_fetches=1,
    ),
    StructuralCase(
        id="disjoint-3-no-merge",
        ranges=[
            RangeByteRequest(0, 10),
            RangeByteRequest(200, 210),
            RangeByteRequest(500, 510),
        ],
        options=MERGE_GAP_50,
        expected_group_sizes=[1, 1, 1],
        expected_contents={
            0: _INDEXED_BLOB[0:10],
            1: _INDEXED_BLOB[200:210],
            2: _INDEXED_BLOB[500:510],
        },
        expected_n_fetches=3,
    ),
    StructuralCase(
        id="adjacent-3-one-merged-group",
        ranges=[
            RangeByteRequest(0, 5),
            RangeByteRequest(10, 15),
            RangeByteRequest(20, 25),
        ],
        options=MERGE_GAP_50,
        expected_group_sizes=[3],
        expected_contents={
            0: _INDEXED_BLOB[0:5],
            1: _INDEXED_BLOB[10:15],
            2: _INDEXED_BLOB[20:25],
        },
        expected_n_fetches=1,
    ),
    StructuralCase(
        id="two-clusters-one-singleton",
        ranges=[
            RangeByteRequest(0, 10),
            RangeByteRequest(20, 30),
            RangeByteRequest(500, 510),
        ],
        options=MERGE_GAP_50,
        expected_group_sizes=[1, 2],
        expected_contents={
            0: _INDEXED_BLOB[0:10],
            1: _INDEXED_BLOB[20:30],
            2: _INDEXED_BLOB[500:510],
        },
        expected_n_fetches=2,
    ),
    StructuralCase(
        id="uncoalescable-mixed-with-range",
        ranges=[
            RangeByteRequest(0, 3),
            OffsetByteRequest(5),
            SuffixByteRequest(2),
            None,
        ],
        options=DEFAULT,
        expected_group_sizes=[1, 1, 1, 1],
        expected_contents={
            0: _INDEXED_BLOB[0:3],
            1: _INDEXED_BLOB[5:],
            2: _INDEXED_BLOB[-2:],
            3: _INDEXED_BLOB,
        },
        expected_n_fetches=4,
    ),
    StructuralCase(
        id="shuffled-input-indices-preserved",
        ranges=[
            RangeByteRequest(500, 510),
            RangeByteRequest(0, 10),
            RangeByteRequest(200, 210),
            RangeByteRequest(300, 310),
        ],
        options=MERGE_GAP_50,
        expected_group_sizes=[1, 1, 1, 1],
        expected_contents={
            0: _INDEXED_BLOB[500:510],
            1: _INDEXED_BLOB[0:10],
            2: _INDEXED_BLOB[200:210],
            3: _INDEXED_BLOB[300:310],
        },
        expected_n_fetches=4,
    ),
    StructuralCase(
        id="cap-prevents-merge-of-close-ranges",
        # 20 + 20 gap + 20 = 60-byte merged span > cap of 50.
        ranges=[RangeByteRequest(0, 20), RangeByteRequest(40, 60)],
        options=CAP_50,
        expected_group_sizes=[1, 1],
        expected_n_fetches=2,
    ),
    StructuralCase(
        id="single-range-larger-than-cap-passes-through",
        # Cap only applies to MERGE decisions; a lone oversized range still fetches.
        ranges=[RangeByteRequest(0, 200)],
        options=CAP_50,
        expected_group_sizes=[1],
        expected_contents={0: _INDEXED_BLOB[0:200]},
        expected_n_fetches=1,
    ),
]


@pytest.mark.parametrize("case", _STRUCTURAL_CASES, ids=lambda c: c.id)
async def test_coalescing_structure_and_contents(case: StructuralCase) -> None:
    """Group structure, byte contents, and fetch-call count for the deterministic cases."""
    fetch = FakeFetch(_INDEXED_BLOB)
    groups = await _collect(coalesced_get(fetch, case.ranges, **case.options))

    assert sorted(len(g) for g in groups) == sorted(case.expected_group_sizes)

    if case.expected_contents is not None:
        assert _contents(groups) == case.expected_contents

    if case.expected_n_fetches is not None:
        assert len(fetch.calls) == case.expected_n_fetches


# ---------------------------------------------------------------------------
# Focused non-parametrized tests for cases with distinctive assertion shapes.
# ---------------------------------------------------------------------------


async def test_within_group_ordering_is_start_offset() -> None:
    """Within a merged group, tuples are ordered by start offset, not input order."""
    fetch = FakeFetch(_INDEXED_BLOB)
    # Two ranges that merge; one has a later start but is listed first in input.
    ranges: list[ByteRequest | None] = [RangeByteRequest(20, 25), RangeByteRequest(0, 5)]
    groups = await _collect(coalesced_get(fetch, ranges, **MERGE_GAP_50))
    assert len(groups) == 1
    # Input index 1 (start=0) comes first, then 0 (start=20).
    assert [idx for idx, _ in groups[0]] == [1, 0]


async def test_adjacent_ranges_fire_single_fetch_spanning_merged_region() -> None:
    """Verify the merged fetch covers exactly the span from min-start to max-end."""
    fetch = FakeFetch(_INDEXED_BLOB)
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(0, 5),
        RangeByteRequest(10, 15),
        RangeByteRequest(20, 25),
    ]
    await _collect(coalesced_get(fetch, ranges, **MERGE_GAP_50))
    assert len(fetch.calls) == 1
    merged_call = fetch.calls[0]
    assert isinstance(merged_call, RangeByteRequest)
    assert merged_call.start == 0
    assert merged_call.end == 25


# ---------------------------------------------------------------------------
# Concurrency and cancellation.
# ---------------------------------------------------------------------------


async def test_max_concurrency_is_honored() -> None:
    """With 10 non-mergeable ranges and max_concurrency=3, peak in-flight must not exceed 3."""
    in_flight = 0
    peak = 0
    lock = asyncio.Lock()

    async def fetch(byte_range: ByteRequest | None) -> Buffer | None:
        nonlocal in_flight, peak
        async with lock:
            in_flight += 1
            peak = max(peak, in_flight)
        # give the scheduler a chance to run other tasks
        await asyncio.sleep(0.01)
        async with lock:
            in_flight -= 1
        return _buf(b"x")

    ranges: list[ByteRequest | None] = [RangeByteRequest(i * 1000, i * 1000 + 1) for i in range(10)]
    opts: Mapping[str, int] = {
        "max_gap_bytes": 0,  # force no merging
        "max_coalesced_bytes": 1 << 20,
        "max_concurrency": 3,
    }
    async for _group in coalesced_get(fetch, ranges, **opts):
        pass
    assert peak <= 3
    assert peak >= 2  # must have been some real concurrency


async def test_consumer_break_cancels_pending_fetches() -> None:
    """Breaking out of the async for should cancel pending fetches rather than let them complete."""
    completed_calls = 0
    cancelled_calls = 0

    async def fetch(byte_range: ByteRequest | None) -> Buffer | None:
        nonlocal completed_calls, cancelled_calls
        assert isinstance(byte_range, RangeByteRequest)
        start = byte_range.start
        try:
            # First fetch returns fast so the async-for body runs and can break.
            # Later fetches sleep long enough that cancellation has room to land.
            await asyncio.sleep(0.001 if start == 0 else 2.0)
        except asyncio.CancelledError:
            cancelled_calls += 1
            raise
        completed_calls += 1
        return _buf(b"x")

    opts: Mapping[str, int] = {
        "max_gap_bytes": -1,  # no merging
        "max_coalesced_bytes": 1 << 20,
        "max_concurrency": 3,
    }
    ranges: list[ByteRequest | None] = [RangeByteRequest(i * 1000, i * 1000 + 1) for i in range(6)]

    agen = coalesced_get(fetch, ranges, **opts)
    async for _group in agen:
        break
    # Explicitly close the generator so its finally block runs (cancelling
    # in-flight tasks) before we make assertions.
    await agen.aclose()

    # The fast task completes; the remaining tasks are either cancelled while
    # sleeping (raising CancelledError into the user try block) or cancelled
    # while still waiting on the semaphore (which doesn't enter the try at all).
    # Either way, none of them should be allowed to complete.
    assert completed_calls == 1
    assert cancelled_calls >= 1
    assert completed_calls + cancelled_calls <= len(ranges)
# ---------------------------------------------------------------------------
# Key-missing semantics.
# ---------------------------------------------------------------------------


async def test_key_missing_from_first_call_raises() -> None:
    """If the very first fetch returns None, the iterator raises an ExceptionGroup containing FileNotFoundError."""
    fetch = FakeFetch(b"x" * 100, key_exists=False)
    ranges: list[ByteRequest | None] = [RangeByteRequest(0, 10), RangeByteRequest(20, 30)]
    with pytest.RaisesGroup(pytest.RaisesExc(FileNotFoundError)):
        await _collect(coalesced_get(fetch, ranges, **DEFAULT))


@pytest.mark.parametrize(
    "byte_range",
    [OffsetByteRequest(5), SuffixByteRequest(5), None],
    ids=["offset", "suffix", "none"],
)
async def test_key_missing_on_uncoalescable_input_raises(
    byte_range: ByteRequest | None,
) -> None:
    """Uncoalescable inputs take a distinct path; key-missing must still raise (wrapped in a group)."""
    fetch = FakeFetch(b"x" * 100, key_exists=False)
    with pytest.RaisesGroup(pytest.RaisesExc(FileNotFoundError)):
        await _collect(coalesced_get(fetch, [byte_range], **DEFAULT))


async def test_key_missing_mid_stream_raises_after_earlier_groups() -> None:
    """If a later fetch returns None, earlier-completed groups yield before the raise."""
    call_count = 0

    async def fetch(byte_range: ByteRequest | None) -> Buffer | None:
        nonlocal call_count
        call_count += 1
        # Deterministic: first call serves, second returns None.
        await asyncio.sleep(0.01 if call_count == 1 else 0.02)
        if call_count >= 2:
            return None
        return _buf(b"ok")

    opts: Mapping[str, int] = {
        "max_gap_bytes": -1,
        "max_coalesced_bytes": 1 << 20,
        "max_concurrency": 1,  # serialize for determinism
    }
    ranges: list[ByteRequest | None] = [RangeByteRequest(0, 2), RangeByteRequest(100, 102)]
    agen = coalesced_get(fetch, ranges, **opts)
    first = await anext(agen)
    assert len(first) == 1
    with pytest.RaisesGroup(pytest.RaisesExc(FileNotFoundError)):
        await anext(agen)


async def test_key_missing_mid_stream_with_concurrency_cancels_late_arrivals() -> None:
    """
    Under max_concurrency > 1, a mid-stream miss should raise FileNotFoundError
    and cancel still-in-flight unrelated tasks rather than wait for them.
    """
    late_gate = asyncio.Event()
    miss_fired = asyncio.Event()
    # Driven by the test body after the first successful yield, so the miss
    # task can't race past the start=0 result.
    fire_miss = asyncio.Event()

    async def fetch(byte_range: ByteRequest | None) -> Buffer | None:
        assert isinstance(byte_range, RangeByteRequest)
        start = byte_range.start
        if start == 0:
            return _buf(b"ok")
        if start == 1000:
            # Wait for the test to give the green light before returning None.
            # This makes ordering deterministic regardless of scheduling.
            await asyncio.wait_for(fire_miss.wait(), timeout=5.0)
            miss_fired.set()
            return None
        # Late arrivals would block on this gate; they should be cancelled
        # before they ever return.
        await asyncio.wait_for(late_gate.wait(), timeout=5.0)
        return _buf(b"ok")

    opts: Mapping[str, int] = {
        "max_gap_bytes": -1,
        "max_coalesced_bytes": 1 << 20,
        "max_concurrency": 3,
    }
    ranges: list[ByteRequest | None] = [RangeByteRequest(i * 1000, i * 1000 + 1) for i in range(7)]

    agen = coalesced_get(fetch, ranges, **opts)
    first = await anext(agen)
    assert len(first) == 1
    idx, buf = first[0]
    assert idx == 0
    assert buf is not None
    # Now that #0 has yielded, signal the miss task to return None.
    fire_miss.set()
    with pytest.RaisesGroup(pytest.RaisesExc(FileNotFoundError)):
        await anext(agen)
    assert miss_fired.is_set()
    # Sanity: late_gate was never set, so the cancellation path is what completed the test.
    assert not late_gate.is_set()


# ---------------------------------------------------------------------------
# Error propagation.
# ---------------------------------------------------------------------------


async def test_fetch_raises_propagates() -> None:
    """An exception raised by fetch propagates on the yield that produced the failing group."""
    fetch = FakeFetch(
        _INDEXED_BLOB,
        raise_on=lambda r: isinstance(r, RangeByteRequest) and r.start >= 100,
    )
    opts: Mapping[str, int] = {
        "max_gap_bytes": -1,
        "max_coalesced_bytes": 1 << 20,
        "max_concurrency": 1,
    }
    ranges: list[ByteRequest | None] = [RangeByteRequest(0, 10), RangeByteRequest(200, 210)]
    with pytest.RaisesGroup(pytest.RaisesExc(OSError, match="injected")):
        await _collect(coalesced_get(fetch, ranges, **opts))
# ---------------------------------------------------------------------------
# Property-style coverage invariant.
# ---------------------------------------------------------------------------


async def test_coverage_invariant_random_inputs() -> None:
    """For any random RangeByteRequest input, every input index appears exactly once."""
    import random

    rng = random.Random(42)
    fetch = FakeFetch(_INDEXED_BLOB)

    ranges: list[ByteRequest | None] = []
    for _ in range(50):
        start = rng.randint(0, 9000)
        length = rng.randint(1, 500)
        ranges.append(RangeByteRequest(start, start + length))

    groups = await _collect(coalesced_get(fetch, ranges, **DEFAULT))
    seen: list[int] = [idx for group in groups for idx, _buf in group]
    assert sorted(seen) == list(range(len(ranges)))

    flat = _contents(groups)
    for i, r in enumerate(ranges):
        assert isinstance(r, RangeByteRequest)
        assert flat[i] == _INDEXED_BLOB[r.start : r.end]


# ---------------------------------------------------------------------------
# Pure-function tests for coalesce_ranges (no async, no fetch).
# ---------------------------------------------------------------------------


def test_coalesce_ranges_empty_input() -> None:
    groups, uncoalescable = coalesce_ranges([], max_gap_bytes=1 << 20, max_coalesced_bytes=16 << 20)
    assert groups == []
    assert uncoalescable == []


def test_coalesce_ranges_separates_coalescable_from_uncoalescable() -> None:
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(0, 10),
        OffsetByteRequest(100),
        SuffixByteRequest(5),
        None,
        RangeByteRequest(20, 30),
    ]
    groups, uncoalescable = coalesce_ranges(ranges, **_grouping(MERGE_GAP_50))

    # Both range requests fall within MERGE_GAP_50's gap budget.
    assert len(groups) == 1
    assert [idx for idx, _ in groups[0]] == [0, 4]

    # Non-RangeByteRequest entries preserve their original input indices.
    assert [(idx, type(req).__name__ if req else None) for idx, req in uncoalescable] == [
        (1, "OffsetByteRequest"),
        (2, "SuffixByteRequest"),
        (3, None),
    ]


def test_coalesce_ranges_no_merge_when_gap_exceeds_budget() -> None:
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(0, 10),
        RangeByteRequest(200, 210),
        RangeByteRequest(500, 510),
    ]
    groups, uncoalescable = coalesce_ranges(ranges, **_grouping(MERGE_GAP_50))
    assert uncoalescable == []
    assert [len(g) for g in groups] == [1, 1, 1]
    assert [idx for g in groups for idx, _ in g] == [0, 1, 2]


def test_coalesce_ranges_merges_within_gap_budget() -> None:
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(0, 5),
        RangeByteRequest(10, 15),
        RangeByteRequest(20, 25),
    ]
    groups, _ = coalesce_ranges(ranges, **_grouping(MERGE_GAP_50))
    assert len(groups) == 1
    assert [idx for idx, _ in groups[0]] == [0, 1, 2]


def test_coalesce_ranges_respects_max_coalesced_bytes() -> None:
    # Gap budget is permissive (1000), but the merged span would exceed CAP_50's
    # 50-byte cap, so the second range starts a new group.
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(0, 30),
        RangeByteRequest(40, 80),
    ]
    groups, _ = coalesce_ranges(ranges, **_grouping(CAP_50))
    assert [len(g) for g in groups] == [1, 1]


def test_coalesce_ranges_groups_are_sorted_by_start() -> None:
    """Input order is irrelevant; groups always emerge in start-offset order."""
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(500, 510),
        RangeByteRequest(0, 10),
        RangeByteRequest(20, 30),
        RangeByteRequest(200, 210),
    ]
    groups, _ = coalesce_ranges(ranges, **_grouping(MERGE_GAP_50))
    # First group is the {0-10, 20-30} cluster (from input indices 1, 2).
    # Then the {200-210} singleton, then {500-510}.
    flat = [idx for g in groups for idx, _ in g]
    assert flat == [1, 2, 3, 0]
    # Within each group, members are sorted by start.
    for g in groups:
        starts = [r.start for _, r in g]
        assert starts == sorted(starts)


def test_coalesce_ranges_overlapping_ranges_merge() -> None:
    """Nested/overlapping ranges have a non-positive 'gap' and always merge."""
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(0, 100),
        RangeByteRequest(50, 60),  # nested
        RangeByteRequest(80, 120),  # overlaps
    ]
    groups, _ = coalesce_ranges(ranges, **_grouping(MERGE_GAP_50))
    assert len(groups) == 1
    assert [idx for idx, _ in groups[0]] == [0, 1, 2]


def test_coalesce_ranges_running_end_handles_nesting() -> None:
    """A subsequent range fully inside the running span must not extend group_end backwards."""
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(0, 1000),  # group_end=1000
        RangeByteRequest(100, 200),  # nested; group_end stays at 1000
        RangeByteRequest(990, 1010),  # gap = -10 from running end, still merges
    ]
    groups, _ = coalesce_ranges(ranges, **_grouping(MERGE_GAP_50))
    assert len(groups) == 1
    assert {idx for idx, _ in groups[0]} == {0, 1, 2}


def test_coalesce_ranges_only_uncoalescable_inputs() -> None:
    ranges: list[ByteRequest | None] = [None, OffsetByteRequest(10), SuffixByteRequest(5)]
    groups, uncoalescable = coalesce_ranges(
        ranges, max_gap_bytes=1 << 20, max_coalesced_bytes=16 << 20
    )
    assert groups == []
    assert [idx for idx, _ in uncoalescable] == [0, 1, 2]


def test_coalesce_ranges_total_index_coverage() -> None:
    """Every input index appears exactly once across groups + uncoalescable."""
    ranges: list[ByteRequest | None] = [
        RangeByteRequest(0, 10),
        None,
        RangeByteRequest(15, 25),
        OffsetByteRequest(100),
        RangeByteRequest(30, 40),
    ]
    groups, uncoalescable = coalesce_ranges(ranges, **_grouping(MERGE_GAP_50))
    seen = sorted([idx for g in groups for idx, _ in g] + [idx for idx, _ in uncoalescable])
    assert seen == list(range(len(ranges)))
b/tests/test_store/test_fsspec_get_ranges.py new file mode 100644 index 0000000000..61d834d22f --- /dev/null +++ b/tests/test_store/test_fsspec_get_ranges.py @@ -0,0 +1,124 @@ +# tests/test_store/test_fsspec_get_ranges.py +"""Lightweight integration tests for FsspecStore.get_ranges using MemoryFileSystem. + +These don't need moto/s3 — they exercise the new method against an in-process +fsspec MemoryFileSystem wrapped in the async wrapper. +""" + +from __future__ import annotations + +import pytest +from packaging.version import parse as parse_version + +from zarr.abc.store import RangeByteRequest +from zarr.core.buffer import Buffer, default_buffer_prototype +from zarr.storage import FsspecStore +from zarr.storage._fsspec import _make_async + +fsspec = pytest.importorskip("fsspec") + +# AsyncFileSystemWrapper (needed to wrap a sync MemoryFileSystem) landed in fsspec 2024.12.0. +# Older versions are pinned by the min-deps CI job, so skip the whole file there. +pytestmark = pytest.mark.skipif( + parse_version(fsspec.__version__) < parse_version("2024.12.0"), + reason="No AsyncFileSystemWrapper", +) + + +@pytest.fixture +def memory_store() -> FsspecStore: + """An FsspecStore backed by fsspec MemoryFileSystem (wrapped async).""" + from fsspec.implementations.memory import MemoryFileSystem + + # Each test gets a clean filesystem; MemoryFileSystem is a singleton per target_options, + # so clear state explicitly. 
+ fs: MemoryFileSystem = MemoryFileSystem() + fs.store.clear() + fs.pseudo_dirs.clear() + async_fs = _make_async(fs) + return FsspecStore(fs=async_fs, path="/root") + + +async def _write(store: FsspecStore, key: str, data: bytes) -> None: + buf = default_buffer_prototype().buffer.from_bytes(data) + await store.set(key, buf) + + +async def test_get_ranges_happy_path(memory_store: FsspecStore) -> None: + blob = bytes(i % 256 for i in range(1024)) + await _write(memory_store, "blob", blob) + proto = default_buffer_prototype() + + ranges = [ + RangeByteRequest(0, 10), + RangeByteRequest(100, 110), + RangeByteRequest(500, 520), + ] + groups: list[list[tuple[int, Buffer | None]]] = [ + list(group) async for group in memory_store.get_ranges("blob", ranges, prototype=proto) + ] + + flat: dict[int, bytes] = {} + for group in groups: + for idx, buf in group: + assert buf is not None + flat[idx] = buf.to_bytes() + + assert flat[0] == blob[0:10] + assert flat[1] == blob[100:110] + assert flat[2] == blob[500:520] + + +async def test_get_ranges_missing_key_raises(memory_store: FsspecStore) -> None: + """A request against a missing key raises BaseExceptionGroup containing FileNotFoundError.""" + proto = default_buffer_prototype() + agen = memory_store.get_ranges("does-not-exist", [RangeByteRequest(0, 10)], prototype=proto) + with pytest.RaisesGroup(pytest.RaisesExc(FileNotFoundError)): + await anext(agen) + + +async def test_get_ranges_forwards_coalescing_kwargs(memory_store: FsspecStore) -> None: + """`max_gap_bytes=-1` forces no merging; we should see three groups for three ranges.""" + blob = bytes(i % 256 for i in range(1024)) + await _write(memory_store, "blob", blob) + proto = default_buffer_prototype() + + ranges = [ + RangeByteRequest(0, 10), + RangeByteRequest(11, 20), # adjacent: would merge under defaults + RangeByteRequest(21, 30), + ] + groups: list[list[tuple[int, Buffer | None]]] = [ + list(group) + async for group in memory_store.get_ranges( + "blob", ranges, 
prototype=proto, max_gap_bytes=-1 + ) + ] + # With merging disabled, every range becomes its own one-tuple group. + assert sorted(len(g) for g in groups) == [1, 1, 1] + + +async def test_get_ranges_mixed_range_types(memory_store: FsspecStore) -> None: + """Covers RangeByteRequest, OffsetByteRequest, SuffixByteRequest, and None in one call.""" + from zarr.abc.store import ByteRequest, OffsetByteRequest, SuffixByteRequest + + blob = bytes(i % 256 for i in range(512)) + await _write(memory_store, "mixed", blob) + proto = default_buffer_prototype() + + ranges: list[ByteRequest | None] = [ + RangeByteRequest(0, 10), + OffsetByteRequest(500), + SuffixByteRequest(12), + None, + ] + flat: dict[int, bytes] = {} + async for group in memory_store.get_ranges("mixed", ranges, prototype=proto): + for idx, buf in group: + assert buf is not None + flat[idx] = buf.to_bytes() + + assert flat[0] == blob[0:10] + assert flat[1] == blob[500:] + assert flat[2] == blob[-12:] + assert flat[3] == blob diff --git a/tests/test_store/test_get_ranges.py b/tests/test_store/test_get_ranges.py new file mode 100644 index 0000000000..8f0c6a4814 --- /dev/null +++ b/tests/test_store/test_get_ranges.py @@ -0,0 +1,142 @@ +# tests/test_store/test_get_ranges.py +"""Tests for `Store.get_ranges` — the ABC default implementation and wrapper delegation. + +`Store.get_ranges` is defined on the ABC with a default implementation built +on `coalesced_get(self.get, ...)`, so every store inherits a working version. +These tests cover that inherited path and the explicit delegation in +`WrapperStore` (which ensures wrapped stores' optimized overrides are honored). +Store-specific overrides (e.g. `FsspecStore`) have their own test modules. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from zarr.abc.store import RangeByteRequest +from zarr.core.buffer import default_buffer_prototype +from zarr.storage import MemoryStore +from zarr.storage._wrapper import WrapperStore + +if TYPE_CHECKING: + from collections.abc import AsyncIterator, Sequence + + from zarr.abc.store import ByteRequest + from zarr.core.buffer import Buffer, BufferPrototype + + +async def _write(store: MemoryStore, key: str, data: bytes) -> None: + buf = default_buffer_prototype().buffer.from_bytes(data) + await store.set(key, buf) + + +async def test_memory_store_inherits_get_ranges_from_abc() -> None: + """MemoryStore doesn't override `get_ranges`; the ABC default must work end-to-end.""" + store = MemoryStore() + blob = bytes(i % 256 for i in range(512)) + await _write(store, "blob", blob) + + ranges = [RangeByteRequest(0, 10), RangeByteRequest(100, 110)] + proto = default_buffer_prototype() + flat: dict[int, bytes] = {} + async for group in store.get_ranges("blob", ranges, prototype=proto): + for idx, buf in group: + assert buf is not None + flat[idx] = buf.to_bytes() + + assert flat[0] == blob[0:10] + assert flat[1] == blob[100:110] + + +async def test_memory_store_get_ranges_missing_key_raises() -> None: + """A missing key on a default-impl store raises BaseExceptionGroup containing FileNotFoundError.""" + store = MemoryStore() + proto = default_buffer_prototype() + agen = store.get_ranges("does-not-exist", [RangeByteRequest(0, 10)], prototype=proto) + with pytest.RaisesGroup(pytest.RaisesExc(FileNotFoundError)): + await anext(agen) + + +async def test_wrapper_store_delegates_get_ranges() -> None: + """WrapperStore.get_ranges must delegate to the wrapped store, not fall back to the default.""" + + class CountingMemoryStore(MemoryStore): + """Tallies get_ranges invocations so we can assert delegation.""" + + get_ranges_calls: int = 0 + + async def get_ranges( + self, + key: str, + 
byte_ranges: Sequence[ByteRequest | None], + *, + prototype: BufferPrototype, + max_concurrency: int = 10, + max_gap_bytes: int = 1 << 20, + max_coalesced_bytes: int = 16 << 20, + ) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]: + type(self).get_ranges_calls += 1 + async for group in super().get_ranges( + key, + byte_ranges, + prototype=prototype, + max_concurrency=max_concurrency, + max_gap_bytes=max_gap_bytes, + max_coalesced_bytes=max_coalesced_bytes, + ): + yield group + + inner = CountingMemoryStore() + blob = b"x" * 100 + await _write(inner, "k", blob) + wrapped = WrapperStore(inner) + + proto = default_buffer_prototype() + groups: list[list[tuple[int, Buffer | None]]] = [ + list(group) + async for group in wrapped.get_ranges("k", [RangeByteRequest(0, 5)], prototype=proto) + ] + + assert CountingMemoryStore.get_ranges_calls == 1 + assert len(groups) == 1 + assert groups[0][0][0] == 0 + + +async def test_wrapper_store_forwards_coalescing_kwargs() -> None: + """Coalescing kwargs flow through WrapperStore to the wrapped store's get_ranges.""" + + class SpyMemoryStore(MemoryStore): + last_max_gap_bytes: int | None = None + + async def get_ranges( + self, + key: str, + byte_ranges: Sequence[ByteRequest | None], + *, + prototype: BufferPrototype, + max_concurrency: int = 10, + max_gap_bytes: int = 1 << 20, + max_coalesced_bytes: int = 16 << 20, + ) -> AsyncIterator[Sequence[tuple[int, Buffer | None]]]: + type(self).last_max_gap_bytes = max_gap_bytes + async for group in super().get_ranges( + key, + byte_ranges, + prototype=prototype, + max_concurrency=max_concurrency, + max_gap_bytes=max_gap_bytes, + max_coalesced_bytes=max_coalesced_bytes, + ): + yield group + + inner = SpyMemoryStore() + await _write(inner, "k", b"y" * 100) + wrapped = WrapperStore(inner) + proto = default_buffer_prototype() + async for _ in wrapped.get_ranges( + "k", [RangeByteRequest(0, 5)], prototype=proto, max_gap_bytes=-1 + ): + pass + + assert SpyMemoryStore.last_max_gap_bytes == -1