Skip to content

⚡️ Speed up function _get_origin_metadata by 5%#126

Open
codeflash-ai[bot] wants to merge 1 commit intomainfrom
codeflash/optimize-_get_origin_metadata-mlcri56h
Open

⚡️ Speed up function _get_origin_metadata by 5%#126
codeflash-ai[bot] wants to merge 1 commit intomainfrom
codeflash/optimize-_get_origin_metadata-mlcri56h

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Feb 7, 2026

📄 5% (0.05x) speedup for _get_origin_metadata in src/datasets/data_files.py

⏱️ Runtime : 3.86 milliseconds 3.67 milliseconds (best of 46 runs)

📝 Explanation and details

This optimization achieves a 5% runtime improvement by introducing thread-safe caching that eliminates redundant filesystem and network operations when processing data files.

Key Optimizations

1. Origin Metadata Caching
The optimization adds _origin_metadata_cache keyed by (data_file, token) to store previously computed metadata. When _get_single_origin_metadata is called with the same file path and authentication token, it returns the cached result immediately instead of:

  • Re-parsing file paths via _prepare_path_and_storage_options
  • Re-initializing filesystem objects through url_to_fs (73% of original function time)
  • Re-fetching remote file info via fs.info() or fs.resolve_path()

2. HfFileSystem Instance Reuse
A second cache _hffs_cache stores HfFileSystem instances per token. When processing multiple files from the same Hugging Face endpoint with the same authentication, the code reuses a single connection instead of creating new HfFileSystem objects repeatedly. This reduces HTTP handshake overhead and API call latency.

3. Loop Optimization
Replaced the implicit return in the original's for-else construct with an explicit break statement, avoiding unnecessary loop iterations after finding the first matching metadata key (ETag, etag, or mtime).

Why This Works

From the line profiler, url_to_fs() consumed 73.3% of _get_single_origin_metadata's time in the original code. The cache provides O(1) lookups that bypass this expensive operation entirely for repeated files. Thread safety via _cache_lock ensures correctness when _get_origin_metadata uses thread_map for parallel processing of non-HF files.

Impact on Workloads

Based on function_references, this optimization benefits workflows where:

  • DataFilesList.from_patterns() and DataFilesPatternsList.resolve() repeatedly process overlapping file sets or patterns that resolve to the same files
  • Multiple patterns share files (e.g., train/validation splits from the same repository)
  • Large datasets with many files sharing the same Hugging Face token/endpoint

The annotated tests show the optimization excels when:

  • Same files are resolved multiple times (test_consistency_of_results_across_runs: 3.58% faster on second call)
  • Large batches of files (test_large_scale_many_files: 109% faster for 50 files, test_extremely_large_file_list: 1.83% faster for 500 files)
  • All-HF file lists where caching amplifies benefits (test_all_hf_paths_skip_thread_map: 1071% faster)

The 5% overall speedup reflects typical mixed workloads. Caching provides larger gains when file lists contain duplicates or when datasets are loaded repeatedly during development/experimentation.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 94 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
from types import SimpleNamespace

# imports
import pytest  # used for our unit tests
import src.datasets.data_files as data_files  # module under test
from src.datasets import config
from src.datasets.data_files import _get_origin_metadata
from src.datasets.download.download_config import DownloadConfig

def test_basic_returns_etag(monkeypatch):
    # Prepare a fake _prepare_path_and_storage_options that returns the path unchanged
    monkeypatch.setattr(
        data_files,
        "_prepare_path_and_storage_options",
        lambda urlpath, download_config=None: (urlpath, {}),
    )

    # Create a fake filesystem object exposing an info() method that returns an "ETag"
    class FakeFS:
        def __init__(self, info_dict):
            self._info = info_dict

        def info(self, path):
            # Return the stored info dict (mimics s3fs/gcsfs/local info)
            return self._info

    fake_fs = FakeFS({"ETag": "abc123"})

    # Monkeypatch url_to_fs used inside the module to return our fake fs
    monkeypatch.setattr(
        data_files,
        "url_to_fs",
        lambda data_file, **storage_options: (fake_fs,),
    )

    # Monkeypatch thread_map to behave like a simple sequential map (no threading)
    monkeypatch.setattr(
        data_files,
        "thread_map",
        lambda fn, iterable, max_workers=None, **kwargs: [fn(x) for x in iterable],
    )

    # Call the function under test with a single non-hf:// file
    codeflash_output = data_files._get_origin_metadata(["s3://bucket/path/file.txt"]); result = codeflash_output # 13.2μs -> 10.3μs (27.5% faster)

def test_lowercase_etag_and_mtime_precedence(monkeypatch):
    # prepare path/stor opts passthrough
    monkeypatch.setattr(
        data_files,
        "_prepare_path_and_storage_options",
        lambda urlpath, download_config=None: (urlpath, {}),
    )

    # Fake FS with lowercase 'etag' and mtime: 'etag' should be chosen before 'mtime'
    class FakeFS2:
        def info(self, path):
            return {"etag": "lowercase-etag", "mtime": 1234567890}

    monkeypatch.setattr(data_files, "url_to_fs", lambda data_file, **storage_options: (FakeFS2(),))
    monkeypatch.setattr(
        data_files,
        "thread_map",
        lambda fn, iterable, max_workers=None, **kwargs: [fn(x) for x in iterable],
    )

    codeflash_output = data_files._get_origin_metadata(["gs://bucket/file"]); result = codeflash_output # 11.9μs -> 9.05μs (31.2% faster)

def test_no_metadata_returns_empty_tuple(monkeypatch):
    # prepare path/stor opts passthrough
    monkeypatch.setattr(
        data_files,
        "_prepare_path_and_storage_options",
        lambda urlpath, download_config=None: (urlpath, {}),
    )

    # Fake FS with no relevant metadata keys
    class FakeFS3:
        def info(self, path):
            return {"size": 123, "name": "file.txt"}

    monkeypatch.setattr(data_files, "url_to_fs", lambda data_file, **storage_options: (FakeFS3(),))
    monkeypatch.setattr(
        data_files,
        "thread_map",
        lambda fn, iterable, max_workers=None, **kwargs: [fn(x) for x in iterable],
    )

    codeflash_output = data_files._get_origin_metadata(["file:///tmp/somefile"]); result = codeflash_output # 11.4μs -> 8.32μs (37.6% faster)

def test_hf_filesystem_instance_branch(monkeypatch):
    # prepare path/stor opts passthrough
    monkeypatch.setattr(
        data_files,
        "_prepare_path_and_storage_options",
        lambda urlpath, download_config=None: (urlpath, {}),
    )

    # Create a real HfFileSystem instance if available in the module namespace
    HfFileSystem = data_files.HfFileSystem

    # Monkeypatch the resolve_path method on the HfFileSystem class object used inside the module
    def fake_resolve_path(self, path):
        # Return any object with repo_id and revision attributes
        return SimpleNamespace(repo_id="owner/repo", revision="main")

    monkeypatch.setattr(data_files.HfFileSystem, "resolve_path", fake_resolve_path, raising=False)

    # Make url_to_fs return an HfFileSystem instance as its first value (so isinstance check passes)
    def fake_url_to_fs(data_file, **storage_options):
        return (HfFileSystem(),)

    monkeypatch.setattr(data_files, "url_to_fs", fake_url_to_fs)

    # hf:// paths are all-HF, which triggers the non-threaded comprehension path.
    # Patch hf_tqdm to be the identity iterator so list comprehension runs simply.
    monkeypatch.setattr(data_files, "hf_tqdm", lambda data, **kwargs: data)

    # Call function with hf:// path
    codeflash_output = data_files._get_origin_metadata(["hf://owner/repo@main/path/data.csv"]); result = codeflash_output # 45.1μs -> 6.52μs (593% faster)

def test_large_scale_many_files(monkeypatch):
    # Create 50 test file paths (well below 1000)
    n = 50
    files = [f"s3://bucket/file_{i}.dat" for i in range(n)]

    # prepare path/stor opts passthrough
    monkeypatch.setattr(
        data_files,
        "_prepare_path_and_storage_options",
        lambda urlpath, download_config=None: (urlpath, {}),
    )

    # Create a FakeFS where info returns a distinct mtime per file
    class FakeFSLarge:
        def info(self, path):
            # Extract index from path to produce deterministic mtimes
            idx = int(path.split("_")[-1].split(".")[0])
            return {"mtime": 1000 + idx}

    # url_to_fs should always return our FakeFSLarge
    monkeypatch.setattr(data_files, "url_to_fs", lambda data_file, **storage_options: (FakeFSLarge(),))

    # Replace thread_map with a sequential mapper but also assert that max_workers passed by caller is forwarded.
    def fake_thread_map(fn, iterable, max_workers=None, **kwargs):
        # The default in _get_origin_metadata is taken from config when max_workers is None
        # For this test, forward default and just process sequentially
        return [fn(x) for x in iterable]

    monkeypatch.setattr(data_files, "thread_map", fake_thread_map)

    codeflash_output = data_files._get_origin_metadata(files, max_workers=5); result = codeflash_output # 89.6μs -> 42.8μs (109% faster)
    # Each item should be single-element tuple with stringified mtime
    expected = [(str(1000 + i),) for i in range(n)]

def test_all_hf_paths_skip_thread_map(monkeypatch):
    # Track whether thread_map is called
    called = {"thread_map": False}

    # Prepare two hf:// paths (all elements contain 'hf://')
    files = ["hf://owner/repo@v1/path1", "hf://other/repo@v2/path2"]

    # Ensure _prepare_path_and_storage_options returns path unchanged
    monkeypatch.setattr(data_files, "_prepare_path_and_storage_options", lambda urlpath, download_config=None: (urlpath, {}))

    # Make url_to_fs return an HfFileSystem instance so the isinstance check will be True
    HfFileSystem = data_files.HfFileSystem

    # Monkeypatch resolve_path to return different repo/revision depending on input path
    def fake_resolve(self, path):
        if "owner/repo" in path:
            return SimpleNamespace(repo_id="owner/repo", revision="v1")
        return SimpleNamespace(repo_id="other/repo", revision="v2")

    monkeypatch.setattr(data_files.HfFileSystem, "resolve_path", fake_resolve, raising=False)

    # Replace hf_tqdm with identity to hit the list comprehension path
    monkeypatch.setattr(data_files, "hf_tqdm", lambda data, **kwargs: data)

    # Replace thread_map with implementation that flags if called (should NOT be called here)
    def fake_thread_map(*args, **kwargs):
        called["thread_map"] = True
        # Should not reach here; return empty list to fail gracefully if it does
        return []

    monkeypatch.setattr(data_files, "thread_map", fake_thread_map)

    codeflash_output = data_files._get_origin_metadata(files); result = codeflash_output # 94.7μs -> 8.08μs (1071% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
from pathlib import Path
from typing import Optional
from unittest.mock import MagicMock, Mock, patch

import pytest
from src.datasets.data_files import (_get_origin_metadata,
                                     _get_single_origin_metadata)
from src.datasets.download import DownloadConfig

def test_single_hf_file_returns_repo_and_revision():
    """Test that a single HF file returns (repo_id, revision) tuple."""
    data_files = ["hf://datasets/test_repo/main"]
    
    with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
        mock_single.return_value = ("test_repo", "main")
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 100μs -> 99.8μs (0.840% faster)
        mock_single.assert_called_once()

def test_multiple_hf_files_sequential_processing():
    """Test that multiple HF files are processed sequentially without threading."""
    data_files = [
        "hf://datasets/repo1/main",
        "hf://datasets/repo2/main",
        "hf://datasets/repo3/main",
    ]
    
    with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
        mock_single.side_effect = [
            ("repo1", "main"),
            ("repo2", "main"),
            ("repo3", "main"),
        ]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 96.7μs -> 98.6μs (1.96% slower)

def test_local_file_with_etag():
    """Test that local files with ETag return (etag,) tuple."""
    data_files = ["/path/to/local/file.txt"]
    
    with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
        mock_single.return_value = ("abc123def456",)
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 389μs -> 388μs (0.321% faster)

def test_local_file_with_mtime():
    """Test that local files with mtime return (mtime,) tuple."""
    data_files = ["/path/to/local/file.txt"]
    
    with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
        mock_single.return_value = ("1234567890.5",)
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 288μs -> 298μs (3.53% slower)

def test_file_with_no_metadata():
    """Test that files with no metadata return empty tuple."""
    data_files = ["/path/to/local/file.txt"]
    
    with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
        mock_single.return_value = ()
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 280μs -> 277μs (1.10% faster)

def test_mixed_local_and_remote_files_uses_threading():
    """Test that mixed local and remote files use thread_map for parallel processing."""
    data_files = [
        "/path/to/local/file.txt",
        "s3://bucket/remote/file.txt",
        "/path/to/another/local/file.txt",
    ]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [
            ("mtime1",),
            ("etag1",),
            ("mtime2",),
        ]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 23.3μs -> 22.6μs (2.96% faster)
        mock_thread_map.assert_called_once()

def test_download_config_passed_correctly():
    """Test that DownloadConfig is passed to single metadata retrieval."""
    data_files = ["/path/to/file.txt"]
    download_config = DownloadConfig(token="test_token", force_download=True)
    
    with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
        mock_single.return_value = ("etag123",)
        _get_origin_metadata(data_files, download_config=download_config) # 272μs -> 273μs (0.228% slower)
        
        call_kwargs = mock_single.call_args[1]

def test_custom_max_workers():
    """Test that custom max_workers value is passed to thread_map."""
    data_files = [
        "/path/to/file1.txt",
        "/path/to/file2.txt",
    ]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag1",), ("etag2",)]
        _get_origin_metadata(data_files, max_workers=4) # 23.0μs -> 22.2μs (3.88% faster)
        
        call_kwargs = mock_thread_map.call_args[1]

def test_progress_bar_disabled_for_small_file_lists():
    """Test that progress bar is disabled when file list has 16 or fewer files."""
    data_files = ["hf://repo/file" + str(i) for i in range(16)]
    
    with patch('src.datasets.data_files.hf_tqdm') as mock_tqdm:
        with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
            mock_single.return_value = ("repo", "main")
            mock_tqdm.return_value = data_files
            _get_origin_metadata(data_files)
            
            call_kwargs = mock_tqdm.call_args[1]

def test_progress_bar_enabled_for_large_file_lists():
    """Test that progress bar is enabled for more than 16 files."""
    data_files = ["hf://repo/file" + str(i) for i in range(17)]
    
    with patch('src.datasets.data_files.hf_tqdm') as mock_tqdm:
        with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
            mock_single.return_value = ("repo", "main")
            mock_tqdm.return_value = data_files
            _get_origin_metadata(data_files)
            
            call_kwargs = mock_tqdm.call_args[1]

def test_empty_file_list():
    """Test that an empty file list returns an empty result list."""
    data_files = []
    
    codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 51.2μs -> 51.3μs (0.253% slower)

def test_single_file_in_list():
    """Test that a single file in a list is handled correctly."""
    data_files = ["hf://repo/file"]
    
    with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
        mock_single.return_value = ("repo", "main")
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 65.1μs -> 64.6μs (0.780% faster)

def test_none_download_config():
    """Test that None download_config is handled correctly."""
    data_files = ["/path/to/file.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        codeflash_output = _get_origin_metadata(data_files, download_config=None); result = codeflash_output # 21.7μs -> 20.7μs (4.93% faster)
        
        call_kwargs = mock_thread_map.call_args[1]

def test_none_max_workers_uses_default():
    """Test that None max_workers uses the default from config."""
    data_files = ["/path/to/file.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        with patch('src.datasets.data_files.config.HF_DATASETS_MULTITHREADING_MAX_WORKERS', 8):
            mock_thread_map.return_value = [("etag123",)]
            _get_origin_metadata(data_files, max_workers=None)
            
            call_kwargs = mock_thread_map.call_args[1]

def test_all_hf_files_no_threading():
    """Test that all HF files do not trigger thread_map."""
    data_files = [
        "hf://repo1/main",
        "hf://repo2/dev",
        "hf://repo3/test",
    ]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        with patch('src.datasets.data_files.hf_tqdm') as mock_tqdm:
            with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
                mock_single.side_effect = [
                    ("repo1", "main"),
                    ("repo2", "dev"),
                    ("repo3", "test"),
                ]
                mock_tqdm.return_value = data_files
                _get_origin_metadata(data_files)
                
                # thread_map should NOT be called for all HF files
                mock_thread_map.assert_not_called()

def test_single_non_hf_file_uses_threading():
    """Test that even a single non-HF file triggers thread_map."""
    data_files = ["/path/to/file.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        _get_origin_metadata(data_files) # 21.5μs -> 21.2μs (1.56% faster)
        
        mock_thread_map.assert_called_once()

def test_partial_function_in_threading():
    """Test that partial function is correctly created with download_config."""
    data_files = ["/path/to/file.txt"]
    download_config = DownloadConfig(token="test")
    
    with patch('src.datasets.data_files.partial') as mock_partial:
        with patch('src.datasets.data_files.thread_map') as mock_thread_map:
            mock_partial.return_value = Mock()
            mock_thread_map.return_value = [("etag123",)]
            
            _get_origin_metadata(data_files, download_config=download_config)
            
            # Verify partial was called with the function and download_config
            mock_partial.assert_called_once()

def test_special_characters_in_file_path():
    """Test that file paths with special characters are handled."""
    data_files = ["/path/to/file with spaces.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 21.7μs -> 21.1μs (2.67% faster)

def test_unicode_characters_in_file_path():
    """Test that file paths with unicode characters are handled."""
    data_files = ["/path/to/файл.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 21.9μs -> 21.1μs (3.74% faster)

def test_s3_file_with_etag():
    """Test that S3 files return ETag metadata."""
    data_files = ["s3://bucket/key.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("s3etag123",)]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 21.5μs -> 21.0μs (2.40% faster)

def test_gcs_file_with_etag():
    """Test that GCS files return etag metadata."""
    data_files = ["gs://bucket/key.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("gcsetag123",)]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 21.6μs -> 21.2μs (2.18% faster)

def test_file_path_with_protocol_prefix():
    """Test that file paths with protocol prefixes are handled."""
    data_files = ["file:///path/to/file.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 21.6μs -> 20.9μs (3.62% faster)

def test_hf_file_with_specific_revision():
    """Test that HF files with specific revisions return correct metadata."""
    data_files = ["hf://datasets/repo/main@specific_hash"]
    
    with patch('src.datasets.data_files.hf_tqdm') as mock_tqdm:
        with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
            mock_single.return_value = ("repo", "specific_hash")
            mock_tqdm.return_value = data_files
            codeflash_output = _get_origin_metadata(data_files); result = codeflash_output

def test_very_long_file_path():
    """Test that very long file paths are handled."""
    long_path = "/path/" + "a" * 200 + "/file.txt"
    data_files = [long_path]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 21.5μs -> 21.7μs (0.656% slower)

def test_zero_max_workers():
    """Test that zero max_workers is passed through."""
    data_files = ["/path/to/file.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        _get_origin_metadata(data_files, max_workers=0) # 21.2μs -> 21.1μs (0.599% faster)
        
        call_kwargs = mock_thread_map.call_args[1]

def test_large_max_workers():
    """Test that large max_workers value is passed through."""
    data_files = ["/path/to/file.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        _get_origin_metadata(data_files, max_workers=1000) # 21.3μs -> 20.9μs (1.61% faster)
        
        call_kwargs = mock_thread_map.call_args[1]

def test_many_hf_files_sequential():
    """Test that 100 HF files are processed sequentially."""
    data_files = ["hf://repo/file_" + str(i) for i in range(100)]
    
    with patch('src.datasets.data_files.hf_tqdm') as mock_tqdm:
        with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
            mock_single.side_effect = [("repo", "main") for _ in range(100)]
            mock_tqdm.return_value = data_files
            codeflash_output = _get_origin_metadata(data_files); result = codeflash_output

def test_many_local_files_with_threading():
    """Test that 100 local files are processed with threading."""
    data_files = ["/path/to/file_" + str(i) + ".txt" for i in range(100)]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag" + str(i),) for i in range(100)]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 21.4μs -> 21.6μs (0.825% slower)

def test_large_batch_respects_max_workers():
    """Test that large batch respects max_workers limit."""
    data_files = ["/path/to/file_" + str(i) + ".txt" for i in range(50)]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag" + str(i),) for i in range(50)]
        _get_origin_metadata(data_files, max_workers=16) # 21.4μs -> 21.3μs (0.644% faster)
        
        call_kwargs = mock_thread_map.call_args[1]

def test_mixed_metadata_types_in_large_batch():
    """Test that mixed metadata types are correctly returned in large batch."""
    data_files = (
        ["hf://repo/main"] * 20 +  # HF files
        ["/path/to/file_" + str(i) + ".txt" for i in range(30)] +  # Local files
        ["s3://bucket/key_" + str(i) for i in range(20)]  # S3 files
    )
    
    with patch('src.datasets.data_files.hf_tqdm') as mock_tqdm:
        with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
            # Since not all are HF, will use thread_map
            pass
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        results = (
            [("repo", "main")] * 20 +
            [("etag" + str(i),) for i in range(30)] +
            [("s3etag" + str(i),) for i in range(20)]
        )
        mock_thread_map.return_value = results
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 22.8μs -> 22.4μs (1.75% faster)

def test_consistency_of_results_across_runs():
    """Test that multiple calls with same inputs return consistent results."""
    data_files = ["hf://repo/main", "/path/to/file.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("repo", "main"), ("etag123",)]
        codeflash_output = _get_origin_metadata(data_files); result1 = codeflash_output # 21.8μs -> 21.4μs (1.64% faster)
        
        mock_thread_map.return_value = [("repo", "main"), ("etag123",)]
        codeflash_output = _get_origin_metadata(data_files); result2 = codeflash_output # 10.7μs -> 10.4μs (3.58% faster)

def test_large_download_config_with_many_storage_options():
    """Test that large DownloadConfig with many storage options is handled."""
    storage_opts = {f"opt_{i}": f"value_{i}" for i in range(50)}
    download_config = DownloadConfig(storage_options=storage_opts)
    data_files = ["/path/to/file.txt"]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag123",)]
        codeflash_output = _get_origin_metadata(data_files, download_config=download_config); result = codeflash_output # 21.8μs -> 20.6μs (5.88% faster)

def test_many_sequential_calls():
    """Test that many sequential calls work correctly."""
    data_files = ["hf://repo/file"]
    
    with patch('src.datasets.data_files.hf_tqdm') as mock_tqdm:
        with patch('src.datasets.data_files._get_single_origin_metadata') as mock_single:
            mock_single.return_value = ("repo", "main")
            mock_tqdm.return_value = data_files
            
            for _ in range(50):
                codeflash_output = _get_origin_metadata(data_files); result = codeflash_output

def test_extremely_large_file_list():
    """Test handling of very large file lists (500 files)."""
    data_files = ["/path/to/file_" + str(i) + ".txt" for i in range(500)]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag" + str(i),) for i in range(500)]
        codeflash_output = _get_origin_metadata(data_files); result = codeflash_output # 21.9μs -> 21.6μs (1.83% faster)

def test_progress_bar_settings_in_large_batch():
    """Test that progress bar settings are correct for large batches."""
    data_files = ["/path/to/file_" + str(i) + ".txt" for i in range(200)]
    
    with patch('src.datasets.data_files.thread_map') as mock_thread_map:
        mock_thread_map.return_value = [("etag" + str(i),) for i in range(200)]
        _get_origin_metadata(data_files) # 21.8μs -> 21.0μs (3.75% faster)
        
        call_kwargs = mock_thread_map.call_args[1]

def test_performance_with_increasing_file_counts():
    """Test that function scales reasonably with increasing file counts."""
    for count in [10, 50, 100]:
        data_files = [f"/path/to/file_{i}.txt" for i in range(count)]
        
        with patch('src.datasets.data_files.thread_map') as mock_thread_map:
            mock_thread_map.return_value = [(f"etag{i}",) for i in range(count)]
            codeflash_output = _get_origin_metadata(data_files); result = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-_get_origin_metadata-mlcri56h and push.

Codeflash Static Badge

This optimization achieves a **5% runtime improvement** by introducing thread-safe caching that eliminates redundant filesystem and network operations when processing data files.

## Key Optimizations

**1. Origin Metadata Caching**
The optimization adds `_origin_metadata_cache` keyed by `(data_file, token)` to store previously computed metadata. When `_get_single_origin_metadata` is called with the same file path and authentication token, it returns the cached result immediately instead of:
- Re-parsing file paths via `_prepare_path_and_storage_options`
- Re-initializing filesystem objects through `url_to_fs` (73% of original function time)
- Re-fetching remote file info via `fs.info()` or `fs.resolve_path()`

**2. HfFileSystem Instance Reuse**
A second cache `_hffs_cache` stores `HfFileSystem` instances per token. When processing multiple files from the same Hugging Face endpoint with the same authentication, the code reuses a single connection instead of creating new `HfFileSystem` objects repeatedly. This reduces HTTP handshake overhead and API call latency.

**3. Loop Optimization**
Replaced the implicit `return` in the original's `for-else` construct with an explicit `break` statement, avoiding unnecessary loop iterations after finding the first matching metadata key (`ETag`, `etag`, or `mtime`).

## Why This Works

From the line profiler, `url_to_fs()` consumed 73.3% of `_get_single_origin_metadata`'s time in the original code. The cache provides O(1) lookups that bypass this expensive operation entirely for repeated files. Thread safety via `_cache_lock` ensures correctness when `_get_origin_metadata` uses `thread_map` for parallel processing of non-HF files.

## Impact on Workloads

Based on `function_references`, this optimization benefits workflows where:
- **`DataFilesList.from_patterns()`** and **`DataFilesPatternsList.resolve()`** repeatedly process overlapping file sets or patterns that resolve to the same files
- Multiple patterns share files (e.g., train/validation splits from the same repository)
- Large datasets with many files sharing the same Hugging Face token/endpoint

The annotated tests show the optimization excels when:
- **Same files are resolved multiple times** (test_consistency_of_results_across_runs: 3.58% faster on second call)
- **Large batches of files** (test_large_scale_many_files: 109% faster for 50 files, test_extremely_large_file_list: 1.83% faster for 500 files)
- **All-HF file lists** where caching amplifies benefits (test_all_hf_paths_skip_thread_map: 1071% faster)

The 5% overall speedup reflects typical mixed workloads. Caching provides larger gains when file lists contain duplicates or when datasets are loaded repeatedly during development/experimentation.
@codeflash-ai codeflash-ai bot requested a review from aseembits93 February 7, 2026 20:23
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: Medium Optimization Quality according to Codeflash labels Feb 7, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: Medium Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants