Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions tests/test_api_introspection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Tests for get/set number of threads API introspection.
"""

import sys
from threading import local as threadlocal

import pytest

from threadpoolctl import (
_ThreadLimitScope,
_determine_thread_limit_scope,
ThreadpoolController,
)

# Make sure we have some BLAS libraries loaded:
from . import utils as _


class FakeThreadLocalAPI(threadlocal):
"""Thread-local num threads setting API."""

def get(self) -> int:
return getattr(self, "num_threads", 17)

def set(self, n: int) -> None:
self.num_threads = n


class FakeProcesswideAPI:
"""Process-wide num threads setting API."""

def __init__(self, num_threads: int):
self.num_threads = num_threads

def get(self) -> int:
return self.num_threads

def set(self, n: int) -> None:
self.num_threads = n


def test_determine_thread_limit_scope_thread_local() -> None:
"""
Check ``_determine_thread_limit_scope()`` can correctly diagnose a trivial
thread-local implementation.
"""
api = FakeThreadLocalAPI()
assert (
_determine_thread_limit_scope(api.get, api.set)
== _ThreadLimitScope.CURRENT_THREAD
)


@pytest.mark.parametrize("default", [1, 17])
def test_determine_thread_limit_scope_processwide(default: int) -> None:
"""
Check ``_determine_thread_limit_scope()`` can correctly diagnose a trivial
process-wide implementation.
"""
api = FakeProcesswideAPI(default)
assert _determine_thread_limit_scope(api.get, api.set) == _ThreadLimitScope.PROCESS


@pytest.mark.skipif(
sys.platform != "linux", reason="We only hardcoded Linux-specific behavior"
)
@pytest.mark.parametrize(
["select_filter", "expected_thread_limit_scope"],
[
(
{"internal_api": "openblas", "threading_layer": "pthreads"},
_ThreadLimitScope.PROCESS,
),
(
{"user_api": "openmp"},
_ThreadLimitScope.CURRENT_THREAD,
),
],
)
def test_api_scope(
select_filter: dict[str, str], expected_thread_limit_scope: str
) -> None:
"""
Check ``_determine_thread_limit_scope()`` against libraries with known
properties, to make sure it detects them correctly. The test is intended
to be of the function's behavior, not of the libraries.
"""
controller = ThreadpoolController().select(**select_filter)
if not controller.lib_controllers:
pytest.skip(f"{select_filter} controller not found")

for lib in controller.lib_controllers:
assert (
_determine_thread_limit_scope(lib.get_num_threads, lib.set_num_threads)
== expected_thread_limit_scope
)
6 changes: 4 additions & 2 deletions tests/test_threadpoolctl.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json
import os
import pytest
Expand Down Expand Up @@ -570,7 +572,7 @@ def test_command_line_command_flag():
)
cli_info = json.loads(output.decode("utf-8"))

this_process_info = threadpool_info()
this_process_info = threadpool_info(extra_info=True)
for lib_info in cli_info:
assert lib_info in this_process_info

Expand All @@ -596,7 +598,7 @@ def test_command_line_import_flag():
)
cli_info = json.loads(result.stdout)

this_process_info = threadpool_info()
this_process_info = threadpool_info(extra_info=True)
for lib_info in cli_info:
assert lib_info in this_process_info

Expand Down
138 changes: 128 additions & 10 deletions threadpoolctl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
import ctypes
import itertools
import textwrap
from typing import final
from threading import Thread
from typing import Callable, final
import warnings
from ctypes.util import find_library
from abc import ABC, abstractmethod
from functools import lru_cache
from contextlib import ContextDecorator
from enum import Enum, auto

__version__ = "3.7.0.dev0"
__all__ = [
Expand Down Expand Up @@ -69,6 +71,91 @@ class _dl_phdr_info(ctypes.Structure):
_RTLD_NOLOAD = ctypes.DEFAULT_MODE


class _ThreadLimitScope(Enum):
"""
What scope does the API affect.
"""

# Using the API sets a limit only on the current thread.
CURRENT_THREAD = auto()
# Using the API sets a limit for every thread in the process; whether or
# not it's a shared process-wide pool or per-thread limit needs to be
# determined some other way.
PROCESS = auto()
# Something else, unexpected; perhaps another variant, perhaps information
# can't be determined under the current configuration.
UNKNOWN = auto()


def _determine_thread_limit_scope(
get_n_threads: Callable[[], int], set_n_threads: Callable[[int], None]
) -> _ThreadLimitScope:
"""
Run some experiments to determine the scope of the given get/set API.

This function might not work if you only have one core available.

This function might not work if you set a limit on a library with an
environment variable.

The function works by changing the number of threads in loaded controllers,
which can be a process-wide change. As such, it is not always thread-safe.
An attempt will be made to restore all settings to their previous state,
but the result may be subtly different, e.g. if "unset" has different
semantics than "set to the default returned value".

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking about it a bit more, we could make this inspection side effect free by calling this function in an isolated via subprocess.run if we really wanted.

The problem would be to make sure that the same native threadpool libraries that are loaded in the current process at the time of the inspection call are also loaded in the subprocess. To do so we could manually call ctypes.CDLL but that might be a bit brittle.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud... also relevant to the "should we hardcode this if introspection isn't available" question...

My current thought is that introspection is a useful debugging tool, but not something that would actually be used. Whatever it does, the library is stuck with it. And so I think it's OK if it's just opt-in, mostly for CLI users doing bug reports or when doing performance debugging.

And so something simpler and minimalist seems sufficient.

I may be wrong, it may be that knowing which it is will be helpful. And if so we can maybe go for more elaborate approaches like subprocess later on.

"""
previous = get_n_threads()

# Some plausible constraints we need to keep in mind:
#
# 1. The API might not allow setting more than the number of (available, or
# physical) cores.
# 2. Some hard limit on number of threads.
try:
# Choose a desired number of threads that is different than the current
# number, and hopefully achievable under the current configuration:
if previous < 2:
expected = 2
Comment thread
itamarst marked this conversation as resolved.
else:
# It's 2 or more, so shrink it slightly:
expected = previous - 1

thread_result = []

def get_and_set() -> None:
set_n_threads(expected)
thread_result.append(get_n_threads())

thread = Thread(target=get_and_set)
thread.start()
thread.join()

# First, getting in the same thread as a set should always give same
# number, if it's a number in a reasonable range. A possible exception
# fo failing this is if the number of thread is limited by available
# CPU, and only one CPU is available. In that case we can't empirically
# determine how the API works. We try to not reach that point here, but
# you can imagine a thread pool implementation that is aware of
# cgroups, in which case a Docker container limited to one core will
# pass the safety check at the start of the function. Perhaps
# cpu_count() from loky should be moved into this package...
if thread_result != [expected]:
return _ThreadLimitScope.UNKNOWN

# Now, check this thread:
if get_n_threads() == expected:
# Setting modified this thread's results too:
return _ThreadLimitScope.PROCESS
elif get_n_threads() == previous:
# Setting modified the other thread, but not this one:
return _ThreadLimitScope.CURRENT_THREAD
else:
# No idea what's going on:
return _ThreadLimitScope.UNKNOWN
finally:
set_n_threads(previous)


class LibController(ABC):
"""Abstract base class for the individual library controllers

Expand Down Expand Up @@ -116,15 +203,28 @@ def __init__(self, *, filepath=None, prefix=None, parent=None):
self.version = self.get_version()
self.set_additional_attributes()

def info(self):
"""Return relevant info wrapped in a dict"""
def info(self, extra_info: bool = False):
"""Return relevant info wrapped in a dict.

Parameters
----------
extra_info : bool

Include extra fields which requires more intrusive actions to
obtain.
"""
hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix")
return {
result = {
"user_api": self.user_api,
"internal_api": self.internal_api,
"num_threads": self.num_threads,
**{k: v for k, v in vars(self).items() if k not in hidden_attrs},
}
if extra_info:
result["thread_limit_scope"] = _determine_thread_limit_scope(
self.get_num_threads, self.set_num_threads
).name.lower()
return result

def set_additional_attributes(self):
"""Set additional attributes meant to be exposed in the info dict"""
Expand Down Expand Up @@ -549,7 +649,7 @@ def _realpath(filepath):


@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
def threadpool_info(extra_info: bool = False):
"""Return the maximal number of threads for each detected library.

Return a list with all the supported libraries that have been found. Each
Expand All @@ -563,8 +663,16 @@ def threadpool_info():
- "num_threads": the current thread limit.

In addition, each library may contain internal_api specific entries.

Parameters
----------
extra_info : bool
Include extra fields which requires more intrusive actions to obtain.

- "thread_limit_scope": When setting the number of threads, what is
affected. Possible values are "process", "current_thread".
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When extra_info is False I would be in favor of retrieving the expected info from a statically hardcoded data based of known semantics for common BLAS and OpenMP implementation and return UNKNOWN otherwise.

Then we can have tests to check that threadpool_info(extra_info=True) always returns the same as threadpool_info(extra_info=False) on all the environments tested by our CI.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do so, we should probably rename extra_info to inspect_scope or something like that.

@itamarst itamarst Jul 1, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comments about subprocess implementation elsewhere, I think this is probably unnecessary.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically I am thinking of this for now at least mostly as debugging and bug report tool, rather than something the API will use.

return ThreadpoolController().info()
return ThreadpoolController().info(extra_info)


class _ThreadpoolLimiter:
Expand Down Expand Up @@ -824,9 +932,19 @@ def _from_controllers(cls, lib_controllers):
new_controller.lib_controllers = lib_controllers
return new_controller

def info(self):
"""Return lib_controllers info as a list of dicts"""
return [lib_controller.info() for lib_controller in self.lib_controllers]
def info(self, extra_info: bool = False):
"""Return lib_controllers info as a list of dicts.

Parameters
----------
extra_info : bool
Include extra fields which requires more intrusive actions to
obtain.
"""
return [
lib_controller.info(extra_info=extra_info)
for lib_controller in self.lib_controllers
]

def select(self, **kwargs):
"""Return a ThreadpoolController containing a subset of its current
Expand Down Expand Up @@ -1290,7 +1408,7 @@ def _main():
if options.command:
exec(options.command)

print(json.dumps(threadpool_info(), indent=2))
print(json.dumps(threadpool_info(extra_info=True), indent=2))


if __name__ == "__main__":
Expand Down
Loading