Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 38 additions & 9 deletions src/datasets/data_files.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import os
import re
from functools import partial
Expand All @@ -19,6 +20,13 @@
from .utils import tqdm as hf_tqdm
from .utils.file_utils import _prepare_path_and_storage_options, is_local_path, is_relative_path, xbasename, xjoin
from .utils.py_utils import string_to_dict
import threading

_origin_metadata_cache: dict[tuple[str, Optional[str]], SingleOriginMetadata] = {}

_hffs_cache: dict[Optional[str], HfFileSystem] = {}

_cache_lock = threading.Lock()


SingleOriginMetadata = Union[tuple[str, str], tuple[str], tuple[()]]
def _get_single_origin_metadata(
    data_file: str,
    download_config: Optional[DownloadConfig] = None,
) -> SingleOriginMetadata:
    """Return metadata identifying the origin/version of a single data file.

    Args:
        data_file: URL or path of the data file.
        download_config: Optional config whose ``token`` is used for Hub access.

    Returns:
        ``(repo_id, revision)`` for Hugging Face Hub files, ``(etag,)`` for
        other remote filesystems, ``(mtime,)`` for local files, or ``()``.

    Only Hub results are memoized in ``_origin_metadata_cache``: for a given
    URL and token, ``(repo_id, revision)`` is stable. Local mtimes and remote
    ETags are deliberately NOT cached — they must change when the file
    changes, and a stale cached value would defeat change detection.
    """
    token = getattr(download_config, "token", None)
    cache_key = (data_file, token)
    # Fast path: return the memoized Hub result if we already resolved this URL.
    with _cache_lock:
        cached = _origin_metadata_cache.get(cache_key)
    if cached is not None:
        return cached

    data_file, storage_options = _prepare_path_and_storage_options(data_file, download_config=download_config)
    fs, *_ = url_to_fs(data_file, **storage_options)
    if isinstance(fs, HfFileSystem):
        resolved_path = fs.resolve_path(data_file)
        result = (resolved_path.repo_id, resolved_path.revision)
    elif data_file.startswith(config.HF_ENDPOINT):
        # Reuse a single HfFileSystem per token to avoid reconnecting repeatedly.
        with _cache_lock:
            hffs = _hffs_cache.get(token)
        if hffs is None:
            # Construct OUTSIDE the lock: the constructor may be slow, and we
            # must not serialize unrelated callers behind it. setdefault below
            # keeps exactly one instance if two threads race here.
            hffs = HfFileSystem(endpoint=config.HF_ENDPOINT, token=token)
            with _cache_lock:
                hffs = _hffs_cache.setdefault(token, hffs)
        data_file = "hf://" + data_file[len(config.HF_ENDPOINT) + 1 :].replace("/resolve/", "@", 1)
        resolved_path = hffs.resolve_path(data_file)
        result = (resolved_path.repo_id, resolved_path.revision)
    else:
        # s3fs uses "ETag", gcsfs uses "etag", and for local we simply check mtime.
        # These values track file content, so return them uncached.
        info = fs.info(data_file)
        for key in ["ETag", "etag", "mtime"]:
            if key in info:
                return (str(info[key]),)
        return ()

    # Only Hub (repo_id, revision) results reach this point; safe to memoize.
    with _cache_lock:
        _origin_metadata_cache[cache_key] = result
    return result


def _get_origin_metadata(
Expand Down