diff --git a/tests/test_api_introspection.py b/tests/test_api_introspection.py new file mode 100644 index 00000000..08637275 --- /dev/null +++ b/tests/test_api_introspection.py @@ -0,0 +1,97 @@ +""" +Tests for get/set number of threads API introspection. +""" + +import sys +from threading import local as threadlocal + +import pytest + +from threadpoolctl import ( + _ThreadLimitScope, + _determine_thread_limit_scope, + ThreadpoolController, +) + +# Make sure we have some BLAS libraries loaded: +from . import utils as _ + + +class FakeThreadLocalAPI(threadlocal): + """Thread-local num threads setting API.""" + + def get(self) -> int: + return getattr(self, "num_threads", 17) + + def set(self, n: int) -> None: + self.num_threads = n + + +class FakeProcesswideAPI: + """Process-wide num threads setting API.""" + + def __init__(self, num_threads: int): + self.num_threads = num_threads + + def get(self) -> int: + return self.num_threads + + def set(self, n: int) -> None: + self.num_threads = n + + +def test_determine_thread_limit_scope_thread_local() -> None: + """ + Check ``_determine_thread_limit_scope()`` can correctly diagnose a trivial + thread-local implementation. + """ + api = FakeThreadLocalAPI() + assert ( + _determine_thread_limit_scope(api.get, api.set) + == _ThreadLimitScope.CURRENT_THREAD + ) + + +@pytest.mark.parametrize("default", [1, 17]) +def test_determine_thread_limit_scope_processwide(default: int) -> None: + """ + Check ``_determine_thread_limit_scope()`` can correctly diagnose a trivial + process-wide implementation. 
+ """ + api = FakeProcesswideAPI(default) + assert _determine_thread_limit_scope(api.get, api.set) == _ThreadLimitScope.PROCESS + + +@pytest.mark.skipif( + sys.platform != "linux", reason="We only hardcoded Linux-specific behavior" +) +@pytest.mark.parametrize( + ["select_filter", "expected_thread_limit_scope"], + [ + ( + {"internal_api": "openblas", "threading_layer": "pthreads"}, + _ThreadLimitScope.PROCESS, + ), + ( + {"user_api": "openmp"}, + _ThreadLimitScope.CURRENT_THREAD, + ), + ], +) +def test_api_scope( + select_filter: dict[str, str], expected_thread_limit_scope: _ThreadLimitScope +) -> None: + """ + Check ``_determine_thread_limit_scope()`` against libraries with known + properties, to make sure it detects them correctly. The test is intended + to be of the function's behavior, not of the libraries. + """ + controller = ThreadpoolController().select(**select_filter) + if not controller.lib_controllers: + pytest.skip(f"{select_filter} controller not found") + + for lib in controller.lib_controllers: + assert ( + _determine_thread_limit_scope(lib.get_num_threads, lib.set_num_threads) + == expected_thread_limit_scope + ) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py index ea27cc2b..46a2c28b 100644 --- a/tests/test_threadpoolctl.py +++ b/tests/test_threadpoolctl.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import os import pytest @@ -570,7 +572,7 @@ def test_command_line_command_flag(): ) cli_info = json.loads(output.decode("utf-8")) - this_process_info = threadpool_info() + this_process_info = threadpool_info(extra_info=True) for lib_info in cli_info: assert lib_info in this_process_info @@ -596,7 +598,7 @@ def test_command_line_import_flag(): ) cli_info = json.loads(result.stdout) - this_process_info = threadpool_info() + this_process_info = threadpool_info(extra_info=True) for lib_info in cli_info: assert lib_info in this_process_info diff --git a/threadpoolctl.py b/threadpoolctl.py index ceed5b88..6632d8c7 100644 --- 
a/threadpoolctl.py +++ b/threadpoolctl.py @@ -17,12 +17,14 @@ import ctypes import itertools import textwrap -from typing import final +from threading import Thread +from typing import Callable, final import warnings from ctypes.util import find_library from abc import ABC, abstractmethod from functools import lru_cache from contextlib import ContextDecorator +from enum import Enum, auto __version__ = "3.7.0.dev0" __all__ = [ @@ -69,6 +71,91 @@ class _dl_phdr_info(ctypes.Structure): _RTLD_NOLOAD = ctypes.DEFAULT_MODE +class _ThreadLimitScope(Enum): + """ + What scope does the API affect. + """ + + # Using the API sets a limit only on the current thread. + CURRENT_THREAD = auto() + # Using the API sets a limit for every thread in the process; whether or + # not it's a shared process-wide pool or per-thread limit needs to be + # determined some other way. + PROCESS = auto() + # Something else, unexpected; perhaps another variant, perhaps information + # can't be determined under the current configuration. + UNKNOWN = auto() + + +def _determine_thread_limit_scope( + get_n_threads: Callable[[], int], set_n_threads: Callable[[int], None] +) -> _ThreadLimitScope: + """ + Run some experiments to determine the scope of the given get/set API. + + This function might not work if you only have one core available. + + This function might not work if you set a limit on a library with an + environment variable. + + The function works by changing the number of threads in loaded controllers, + which can be a process-wide change. As such, it is not always thread-safe. + An attempt will be made to restore all settings to their previous state, + but the result may be subtly different, e.g. if "unset" has different + semantics than "set to the default returned value". + """ + previous = get_n_threads() + + # Some plausible constraints we need to keep in mind: + # + # 1. The API might not allow setting more than the number of (available, or + # physical) cores. + # 2. 
Some hard limit on number of threads. + try: + # Choose a desired number of threads that is different than the current + # number, and hopefully achievable under the current configuration: + if previous < 2: + expected = 2 + else: + # It's 2 or more, so shrink it slightly: + expected = previous - 1 + + thread_result = [] + + def get_and_set() -> None: + set_n_threads(expected) + thread_result.append(get_n_threads()) + + thread = Thread(target=get_and_set) + thread.start() + thread.join() + + # First, getting in the same thread as a set should always give the same + # number, if it's a number in a reasonable range. A possible exception + # for failing this is if the number of threads is limited by available + # CPU, and only one CPU is available. In that case we can't empirically + # determine how the API works. We try to not reach that point here, but + # you can imagine a thread pool implementation that is aware of + # cgroups, in which case a Docker container limited to one core will + # pass the safety check at the start of the function. Perhaps + # cpu_count() from loky should be moved into this package... + if thread_result != [expected]: + return _ThreadLimitScope.UNKNOWN + + # Now, check this thread: + if get_n_threads() == expected: + # Setting modified this thread's results too: + return _ThreadLimitScope.PROCESS + elif get_n_threads() == previous: + # Setting modified the other thread, but not this one: + return _ThreadLimitScope.CURRENT_THREAD + else: + # No idea what's going on: + return _ThreadLimitScope.UNKNOWN + finally: + set_n_threads(previous) + + class LibController(ABC): """Abstract base class for the individual library controllers @@ -116,15 +203,28 @@ def __init__(self, *, filepath=None, prefix=None, parent=None): self.version = self.get_version() self.set_additional_attributes() - def info(self): - """Return relevant info wrapped in a dict""" + def info(self, extra_info: bool = False): + """Return relevant info wrapped in a dict. 
+ + Parameters + ---------- + extra_info : bool + + Include extra fields which require more intrusive actions to + obtain. + """ hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix") - return { + result = { "user_api": self.user_api, "internal_api": self.internal_api, "num_threads": self.num_threads, **{k: v for k, v in vars(self).items() if k not in hidden_attrs}, } + if extra_info: + result["thread_limit_scope"] = _determine_thread_limit_scope( + self.get_num_threads, self.set_num_threads + ).name.lower() + return result def set_additional_attributes(self): """Set additional attributes meant to be exposed in the info dict""" @@ -549,7 +649,7 @@ def _realpath(filepath): @_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS) -def threadpool_info(): +def threadpool_info(extra_info: bool = False): """Return the maximal number of threads for each detected library. Return a list with all the supported libraries that have been found. Each @@ -563,8 +663,16 @@ def threadpool_info(): - "num_threads": the current thread limit. In addition, each library may contain internal_api specific entries. + + Parameters + ---------- + extra_info : bool + Include extra fields which require more intrusive actions to obtain. + + - "thread_limit_scope": When setting the number of threads, what is + affected. Possible values are "process", "current_thread", "unknown". + """ - return ThreadpoolController().info() + return ThreadpoolController().info(extra_info) class _ThreadpoolLimiter: @@ -824,9 +932,19 @@ def _from_controllers(cls, lib_controllers): new_controller.lib_controllers = lib_controllers return new_controller - def info(self): - """Return lib_controllers info as a list of dicts""" - return [lib_controller.info() for lib_controller in self.lib_controllers] + def info(self, extra_info: bool = False): + """Return lib_controllers info as a list of dicts. 
+ + Parameters + ---------- + extra_info : bool + Include extra fields which require more intrusive actions to + obtain. + """ + return [ + lib_controller.info(extra_info=extra_info) + for lib_controller in self.lib_controllers + ] def select(self, **kwargs): """Return a ThreadpoolController containing a subset of its current @@ -1290,7 +1408,7 @@ def _main(): if options.command: exec(options.command) - print(json.dumps(threadpool_info(), indent=2)) + print(json.dumps(threadpool_info(extra_info=True), indent=2)) if __name__ == "__main__":