-
Notifications
You must be signed in to change notification settings - Fork 37
Thread limit introspection API, part 1: API scope #213
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
28bd246
128c4ef
8637248
6946867
31bbcf5
f27e6f2
0d4dde5
60dd11c
2c985b6
bb7b528
7fee2dc
1c6ae79
9ad8382
d306c55
8740d13
12a9d96
95d77a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| """ | ||
| Tests for get/set number of threads API introspection. | ||
| """ | ||
|
|
||
| import sys | ||
| from threading import local as threadlocal | ||
|
|
||
| import pytest | ||
|
|
||
| from threadpoolctl import ( | ||
| _ThreadLimitScope, | ||
| _determine_thread_limit_scope, | ||
| ThreadpoolController, | ||
| ) | ||
|
|
||
| # Make sure we have some BLAS libraries loaded: | ||
| from . import utils as _ | ||
|
|
||
|
|
||
| class FakeThreadLocalAPI(threadlocal): | ||
| """Thread-local num threads setting API.""" | ||
|
|
||
| def get(self) -> int: | ||
| return getattr(self, "num_threads", 17) | ||
|
|
||
| def set(self, n: int) -> None: | ||
| self.num_threads = n | ||
|
|
||
|
|
||
| class FakeProcesswideAPI: | ||
| """Process-wide num threads setting API.""" | ||
|
|
||
| def __init__(self, num_threads: int): | ||
| self.num_threads = num_threads | ||
|
|
||
| def get(self) -> int: | ||
| return self.num_threads | ||
|
|
||
| def set(self, n: int) -> None: | ||
| self.num_threads = n | ||
|
|
||
|
|
||
| def test_determine_thread_limit_scope_thread_local() -> None: | ||
| """ | ||
| Check ``_determine_thread_limit_scope()`` can correctly diagnose a trivial | ||
| thread-local implementation. | ||
| """ | ||
| api = FakeThreadLocalAPI() | ||
| assert ( | ||
| _determine_thread_limit_scope(api.get, api.set) | ||
| == _ThreadLimitScope.CURRENT_THREAD | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("default", [1, 17]) | ||
| def test_determine_thread_limit_scope_processwide(default: int) -> None: | ||
| """ | ||
| Check ``_determine_thread_limit_scope()`` can correctly diagnose a trivial | ||
| process-wide implementation. | ||
| """ | ||
| api = FakeProcesswideAPI(default) | ||
| assert _determine_thread_limit_scope(api.get, api.set) == _ThreadLimitScope.PROCESS | ||
|
|
||
|
|
||
| @pytest.mark.skipif( | ||
| sys.platform != "linux", reason="We only hardcoded Linux-specific behavior" | ||
| ) | ||
| @pytest.mark.parametrize( | ||
| ["select_filter", "expected_thread_limit_scope"], | ||
| [ | ||
| ( | ||
| {"internal_api": "openblas", "threading_layer": "pthreads"}, | ||
| _ThreadLimitScope.PROCESS, | ||
| ), | ||
| ( | ||
| {"user_api": "openmp"}, | ||
| _ThreadLimitScope.CURRENT_THREAD, | ||
| ), | ||
| ], | ||
| ) | ||
| def test_api_scope( | ||
| select_filter: dict[str, str], expected_thread_limit_scope: str | ||
| ) -> None: | ||
| """ | ||
| Check ``_determine_thread_limit_scope()`` against libraries with known | ||
| properties, to make sure it detects them correctly. The test is intended | ||
| to be of the function's behavior, not of the libraries. | ||
| """ | ||
| controller = ThreadpoolController().select(**select_filter) | ||
| if not controller.lib_controllers: | ||
| pytest.skip(f"{select_filter} controller not found") | ||
|
|
||
| for lib in controller.lib_controllers: | ||
| assert ( | ||
| _determine_thread_limit_scope(lib.get_num_threads, lib.set_num_threads) | ||
| == expected_thread_limit_scope | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,12 +17,14 @@ | |
| import ctypes | ||
| import itertools | ||
| import textwrap | ||
| from typing import final | ||
| from threading import Thread | ||
| from typing import Callable, final | ||
| import warnings | ||
| from ctypes.util import find_library | ||
| from abc import ABC, abstractmethod | ||
| from functools import lru_cache | ||
| from contextlib import ContextDecorator | ||
| from enum import Enum, auto | ||
|
|
||
| __version__ = "3.7.0.dev0" | ||
| __all__ = [ | ||
|
|
@@ -69,6 +71,91 @@ class _dl_phdr_info(ctypes.Structure): | |
| _RTLD_NOLOAD = ctypes.DEFAULT_MODE | ||
|
|
||
|
|
||
| class _ThreadLimitScope(Enum): | ||
| """ | ||
| What scope does the API affect. | ||
| """ | ||
|
|
||
| # Using the API sets a limit only on the current thread. | ||
| CURRENT_THREAD = auto() | ||
| # Using the API sets a limit for every thread in the process; whether or | ||
| # not it's a shared process-wide pool or per-thread limit needs to be | ||
| # determined some other way. | ||
| PROCESS = auto() | ||
| # Something else, unexpected; perhaps another variant, perhaps information | ||
| # can't be determined under the current configuration. | ||
| UNKNOWN = auto() | ||
|
|
||
|
|
||
| def _determine_thread_limit_scope( | ||
| get_n_threads: Callable[[], int], set_n_threads: Callable[[int], None] | ||
| ) -> _ThreadLimitScope: | ||
| """ | ||
| Run some experiments to determine the scope of the given get/set API. | ||
|
|
||
| This function might not work if you only have one core available. | ||
|
|
||
| This function might not work if you set a limit on a library with an | ||
| environment variable. | ||
|
|
||
| The function works by changing the number of threads in loaded controllers, | ||
| which can be a process-wide change. As such, it is not always thread-safe. | ||
| An attempt will be made to restore all settings to their previous state, | ||
| but the result may be subtly different, e.g. if "unset" has different | ||
| semantics than "set to the default returned value". | ||
| """ | ||
| previous = get_n_threads() | ||
|
|
||
| # Some plausible constraints we need to keep in mind: | ||
| # | ||
| # 1. The API might not allow setting more than the number of (available, or | ||
| # physical) cores. | ||
| # 2. Some hard limit on number of threads. | ||
| try: | ||
| # Choose a desired number of threads that is different than the current | ||
| # number, and hopefully achievable under the current configuration: | ||
| if previous < 2: | ||
| expected = 2 | ||
|
itamarst marked this conversation as resolved.
|
||
| else: | ||
| # It's 2 or more, so shrink it slightly: | ||
| expected = previous - 1 | ||
|
|
||
| thread_result = [] | ||
|
|
||
| def get_and_set() -> None: | ||
| set_n_threads(expected) | ||
| thread_result.append(get_n_threads()) | ||
|
|
||
| thread = Thread(target=get_and_set) | ||
| thread.start() | ||
| thread.join() | ||
|
|
||
| # First, getting in the same thread as a set should always give same | ||
| # number, if it's a number in a reasonable range. A possible exception | ||
| # for failing this is if the number of threads is limited by available | ||
| # CPU, and only one CPU is available. In that case we can't empirically | ||
| # determine how the API works. We try to not reach that point here, but | ||
| # you can imagine a thread pool implementation that is aware of | ||
| # cgroups, in which case a Docker container limited to one core will | ||
| # pass the safety check at the start of the function. Perhaps | ||
| # cpu_count() from loky should be moved into this package... | ||
| if thread_result != [expected]: | ||
| return _ThreadLimitScope.UNKNOWN | ||
|
|
||
| # Now, check this thread: | ||
| if get_n_threads() == expected: | ||
| # Setting modified this thread's results too: | ||
| return _ThreadLimitScope.PROCESS | ||
| elif get_n_threads() == previous: | ||
| # Setting modified the other thread, but not this one: | ||
| return _ThreadLimitScope.CURRENT_THREAD | ||
| else: | ||
| # No idea what's going on: | ||
| return _ThreadLimitScope.UNKNOWN | ||
| finally: | ||
| set_n_threads(previous) | ||
|
|
||
|
|
||
| class LibController(ABC): | ||
| """Abstract base class for the individual library controllers | ||
|
|
||
|
|
@@ -116,15 +203,28 @@ def __init__(self, *, filepath=None, prefix=None, parent=None): | |
| self.version = self.get_version() | ||
| self.set_additional_attributes() | ||
|
|
||
| def info(self): | ||
| """Return relevant info wrapped in a dict""" | ||
| def info(self, extra_info: bool = False): | ||
| """Return relevant info wrapped in a dict. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| extra_info : bool | ||
|
|
||
| Include extra fields which require more intrusive actions to | ||
| obtain. | ||
| """ | ||
| hidden_attrs = ("dynlib", "parent", "_symbol_prefix", "_symbol_suffix") | ||
| return { | ||
| result = { | ||
| "user_api": self.user_api, | ||
| "internal_api": self.internal_api, | ||
| "num_threads": self.num_threads, | ||
| **{k: v for k, v in vars(self).items() if k not in hidden_attrs}, | ||
| } | ||
| if extra_info: | ||
| result["thread_limit_scope"] = _determine_thread_limit_scope( | ||
| self.get_num_threads, self.set_num_threads | ||
| ).name.lower() | ||
| return result | ||
|
|
||
| def set_additional_attributes(self): | ||
| """Set additional attributes meant to be exposed in the info dict""" | ||
|
|
@@ -549,7 +649,7 @@ def _realpath(filepath): | |
|
|
||
|
|
||
| @_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS) | ||
| def threadpool_info(): | ||
| def threadpool_info(extra_info: bool = False): | ||
| """Return the maximal number of threads for each detected library. | ||
|
|
||
| Return a list with all the supported libraries that have been found. Each | ||
|
|
@@ -563,8 +663,16 @@ def threadpool_info(): | |
| - "num_threads": the current thread limit. | ||
|
|
||
| In addition, each library may contain internal_api specific entries. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| extra_info : bool | ||
| Include extra fields which require more intrusive actions to obtain. | ||
|
|
||
| - "thread_limit_scope": When setting the number of threads, what is | ||
| affected. Possible values are "process", "current_thread", "unknown". | ||
| """ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When Then we can have tests to check that
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we do so, we should probably rename
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See my comments about subprocess implementation elsewhere, I think this is probably unnecessary.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Basically I am thinking of this for now at least mostly as debugging and bug report tool, rather than something the API will use. |
||
| return ThreadpoolController().info() | ||
| return ThreadpoolController().info(extra_info) | ||
|
|
||
|
|
||
| class _ThreadpoolLimiter: | ||
|
|
@@ -824,9 +932,19 @@ def _from_controllers(cls, lib_controllers): | |
| new_controller.lib_controllers = lib_controllers | ||
| return new_controller | ||
|
|
||
| def info(self): | ||
| """Return lib_controllers info as a list of dicts""" | ||
| return [lib_controller.info() for lib_controller in self.lib_controllers] | ||
| def info(self, extra_info: bool = False): | ||
| """Return lib_controllers info as a list of dicts. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| extra_info : bool | ||
| Include extra fields which require more intrusive actions to | ||
| obtain. | ||
| """ | ||
| return [ | ||
| lib_controller.info(extra_info=extra_info) | ||
| for lib_controller in self.lib_controllers | ||
| ] | ||
|
|
||
| def select(self, **kwargs): | ||
| """Return a ThreadpoolController containing a subset of its current | ||
|
|
@@ -1290,7 +1408,7 @@ def _main(): | |
| if options.command: | ||
| exec(options.command) | ||
|
|
||
| print(json.dumps(threadpool_info(), indent=2)) | ||
| print(json.dumps(threadpool_info(extra_info=True), indent=2)) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking about it a bit more, we could make this inspection side-effect free by calling this function in an isolated process via
`subprocess.run` if we really wanted. The problem would be to make sure that the same native threadpool libraries that are loaded in the current process at the time of the inspection call are also loaded in the subprocess. To do so we could manually call
`ctypes.CDLL` but that might be a bit brittle.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thinking out loud... also relevant to the "should we hardcode this if introspection isn't available" question...
My current thought is that introspection is a useful debugging tool, but not something that would actually be used. Whatever it does, the library is stuck with it. And so I think it's OK if it's just opt-in, mostly for CLI users doing bug reports or when doing performance debugging.
And so something simpler and minimalist seems sufficient.
I may be wrong, it may be that knowing which it is will be helpful. And if so we can maybe go for more elaborate approaches like subprocess later on.