diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py
index 0272304ee3..a4e8fabb51 100644
--- a/devito/arch/archinfo.py
+++ b/devito/arch/archinfo.py
@@ -1,6 +1,6 @@
 """Collection of utilities to detect properties of the underlying architecture."""
 
-from functools import cached_property
+from functools import cache, cached_property
 from subprocess import PIPE, Popen, DEVNULL, run
 from pathlib import Path
 import ctypes
@@ -14,7 +14,7 @@ import psutil
 
 from devito.logger import warning
-from devito.tools import as_tuple, all_equal, memoized_func
+from devito.tools import as_tuple, all_equal
 
 __all__ = ['platform_registry', 'get_cpu_info', 'get_gpu_info', 'get_nvidia_cc',
            'get_cuda_path', 'get_hip_path', 'check_cuda_runtime', 'get_m1_llvm_path',
@@ -41,7 +41,7 @@
            'PVC', 'INTELGPUMAX', 'MAX1100', 'MAX1550']
 
 
-@memoized_func
+@cache
 def get_cpu_info():
     """Attempt CPU info autodetection."""
 
@@ -163,7 +163,7 @@ def get_cpu_brand():
     return cpu_info
 
 
-@memoized_func
+@cache
 def get_gpu_info():
     """Attempt GPU info autodetection."""
 
@@ -488,7 +488,7 @@ def parse_product_arch():
     return None
 
 
-@memoized_func
+@cache
 def get_nvidia_cc():
     libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll')
     for libname in libnames:
@@ -511,7 +511,7 @@ def get_nvidia_cc():
     return 10*cc_major.value + cc_minor.value
 
 
-@memoized_func
+@cache
 def get_cuda_path():
     # *** First try: via commonly used environment variables
    for i in ['CUDA_HOME', 'CUDA_ROOT']:
@@ -531,7 +531,7 @@ def get_cuda_path():
     return None
 
 
-@memoized_func
+@cache
 def get_advisor_path():
     """
     Detect if Intel Advisor is installed on the machine and return
@@ -552,7 +552,7 @@ def get_advisor_path():
     return path
 
 
-@memoized_func
+@cache
 def get_hip_path():
     # *** First try: via commonly used environment variables
     for i in ['HIP_HOME']:
@@ -563,7 +563,7 @@ def get_hip_path():
     return None
 
 
-@memoized_func
+@cache
 def get_m1_llvm_path(language):
     # Check if Apple's llvm is installed (installable via Homebrew), which supports
     # OpenMP.
@@ -595,7 +595,7 @@ def get_m1_llvm_path(language):
     return None
 
 
-@memoized_func
+@cache
 def check_cuda_runtime():
     libnames = ('libcudart.so', 'libcudart.dylib', 'cudart.dll')
     for libname in libnames:
@@ -623,7 +623,7 @@ def check_cuda_runtime():
     warning("Unable to check compatibility of NVidia driver and runtime")
 
 
-@memoized_func
+@cache
 def lscpu():
     try:
         p1 = Popen(['lscpu'], stdout=PIPE, stderr=PIPE)
@@ -645,7 +645,7 @@ def lscpu():
     return {}
 
 
-@memoized_func
+@cache
 def get_platform():
     """Attempt Platform autodetection."""
 
@@ -1111,7 +1111,7 @@ def march(cls):
         return fallback
 
 
-@memoized_func
+@cache
 def node_max_mem_trans_nbytes(platform):
     """
     Return the maximum memory transaction size in bytes for the underlying
diff --git a/devito/arch/compiler.py b/devito/arch/compiler.py
index 0dfeafae14..28d00e9a03 100644
--- a/devito/arch/compiler.py
+++ b/devito/arch/compiler.py
@@ -1,4 +1,4 @@
-from functools import partial
+from functools import cache, partial
 from hashlib import sha1
 from os import environ, path, makedirs
 from packaging.version import Version
@@ -19,13 +19,12 @@
 from devito.exceptions import CompilationError
 from devito.logger import debug, warning
 from devito.parameters import configuration
-from devito.tools import (as_list, change_directory, filter_ordered,
-                          memoized_func, make_tempdir)
+from devito.tools import as_list, change_directory, filter_ordered, make_tempdir
 
 __all__ = ['sniff_mpi_distro', 'compiler_registry']
 
 
-@memoized_func
+@cache
 def sniff_compiler_version(cc, allow_fail=False):
     """
     Detect the compiler version.
@@ -99,7 +98,7 @@ def sniff_compiler_version(cc, allow_fail=False):
     return ver
 
 
-@memoized_func
+@cache
 def sniff_mpi_distro(mpiexec):
     """
     Detect the MPI version.
@@ -117,7 +116,7 @@ def sniff_mpi_distro(mpiexec):
     return 'unknown'
 
 
-@memoized_func
+@cache
 def sniff_mpi_flags(mpicc='mpicc'):
     mpi_distro = sniff_mpi_distro('mpiexec')
     if mpi_distro != 'OpenMPI':
@@ -131,7 +130,7 @@ def sniff_mpi_flags(mpicc='mpicc'):
     return compile_flags.split(), link_flags.split()
 
 
-@memoized_func
+@cache
 def call_capture_output(cmd):
     """
     Memoize calls to codepy's `call_capture_output` to avoid leaking memory due
diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py
index b0733dac56..068ebabb62 100644
--- a/devito/tools/memoization.py
+++ b/devito/tools/memoization.py
@@ -1,135 +1,217 @@
-from collections.abc import Callable, Hashable
-from functools import lru_cache, partial
-from itertools import tee
-from typing import TypeVar
+from collections.abc import Hashable, Iterator
+from functools import lru_cache, partial, update_wrapper
+from threading import RLock, local
+from typing import Callable, Concatenate, Generic, ParamSpec, TypeVar
 
-__all__ = ['memoized_func', 'memoized_meth', 'memoized_generator', 'CacheInstances']
+__all__ = ['memoized_meth', 'memoized_generator', 'CacheInstances']
 
 
-class memoized_func:
-    """
-    Decorator. Caches a function's return value each time it is called.
-    If called later with the same arguments, the cached value is returned
-    (not reevaluated). This decorator may also be used on class methods,
-    but it will cache at the class level; to cache at the instance level,
-    use ``memoized_meth``.
-    Adapted from: ::
+# Type variables for memoized method decorators
+InstanceType = TypeVar('InstanceType', contravariant=True)
+ParamsType = ParamSpec('ParamsType')
+ReturnType = TypeVar('ReturnType', covariant=True)
 
-        https://wiki.python.org/moin/PythonDecoratorLibrary#Memoize
-    """
-    def __init__(self, func):
-        self.func = func
-        self.cache = {}
-
-    def __call__(self, *args, **kw):
-        if not isinstance(args, Hashable):
-            # Uncacheable, a list, for instance.
-            # Better to not cache than blow up.
-            return self.func(*args, **kw)
-        key = (self.func, args, frozenset(kw.items()))
-        if key in self.cache:
-            return self.cache[key]
-        else:
-            value = self.func(*args, **kw)
-            self.cache[key] = value
-            return value
-
-    def __repr__(self):
-        """Return the function's docstring."""
-        return self.func.__doc__
-
-    def __get__(self, obj, objtype):
-        """Support instance methods."""
-        return partial(self.__call__, obj)
-
-
-class memoized_meth:
+class memoized_meth(Generic[InstanceType, ParamsType, ReturnType]):
+    """
+    Decorator for a cached instance method. There is one cache per thread, stored
+    on the object instance itself.
     """
-    Decorator. Cache the return value of a class method.
-
-    Unlike ``memoized_func``, the return value of a given method invocation
-    will be cached on the instance whose method was invoked. All arguments
-    passed to a method decorated with memoize must be hashable.
-
-    If a memoized method is invoked directly on its class the result will not
-    be cached. Instead the method will be invoked like a static method: ::
-
-        class Obj:
-            @memoize
-            def add_to(self, arg):
-                return self + arg
-
-        Obj.add_to(1)  # not enough arguments
-        Obj.add_to(1, 2)  # returns 3, result is not cached
+    def __init__(self, meth: Callable[Concatenate[InstanceType, ParamsType],
+                                      ReturnType]) -> None:
+        self._meth = meth
+        self._lock = RLock()  # Lock to safely initialize the thread-local object
+        update_wrapper(self, self._meth)
 
-    Adapted from: ::
+    def __get__(self, obj: InstanceType, cls: type[InstanceType] | None = None) \
+            -> Callable[ParamsType, ReturnType]:
+        """
+        Binds the memoized method to an instance.
+        """
+        return partial(self, obj)
 
-        code.activestate.com/recipes/577452-a-memoize-decorator-for-instance-methods/
-    """
+    def _get_cache(self, obj: InstanceType) -> dict[Hashable, ReturnType]:
+        """
+        Retrieves the thread-local cache for the given object instance, initializing
+        it if necessary.
+        """
+        # A try/except is typically faster on the happy path than a hasattr check
+        _local: local
+        try:
+            # Attempt to access the thread-local data
+            _local = obj._memoized_meth__local
 
-    def __init__(self, func):
-        self.func = func
+        # If the thread-local data doesn't exist yet, initialize it
+        except AttributeError:
+            with self._lock:
+                # Check again in case another thread initialized it in the meantime
+                if not hasattr(obj, '_memoized_meth__local'):
+                    # Initialize the local data if it doesn't exist
+                    obj._memoized_meth__local = local()
 
-    def __get__(self, obj, objtype=None):
-        if obj is None:
-            return self.func
-        return partial(self, obj)
+            # Get the thread-local data
+            _local = obj._memoized_meth__local
 
-    def __call__(self, *args, **kw):
-        if not isinstance(args, Hashable):
-            # Uncacheable, a list, for instance.
-            # Better to not cache than blow up.
-            return self.func(*args)
-        obj = args[0]
+        # Local data is initialized; create or retrieve the cache
         try:
-            cache = obj.__cache_meth
+            return _local.cache
         except AttributeError:
-            cache = obj.__cache_meth = {}
-        key = (self.func, args[1:], frozenset(kw.items()))
+            _local.cache = {}
+            return _local.cache
+
+    def __call__(self, obj: InstanceType,
+                 *args: ParamsType.args, **kwargs: ParamsType.kwargs) -> ReturnType:
+        """
+        Invokes the memoized method, caching the result if it hasn't been evaluated yet.
+        """
+        # Get the local cache for the object instance
+        cache = self._get_cache(obj)
+        key = (self._meth, args, frozenset(kwargs.items()))
         try:
+            # Try to retrieve the cached value
             res = cache[key]
         except KeyError:
-            res = cache[key] = self.func(*args, **kw)
+            # If not cached, compute the value
+            res = cache[key] = self._meth(obj, *args, **kwargs)
+
         return res
 
 
-class memoized_generator:
+# Describes the type of element yielded by a cached iterator
+YieldType = TypeVar('YieldType', covariant=True)
+
+
+class SafeTee(Iterator[YieldType]):
     """
-    Decorator. Cache the return value of an instance generator method.
+    A thread-safe version of `itertools.tee` that allows multiple iterators to safely
+    share the same buffer.
+
+    In theory, this comes at a performance cost when iterating elements that haven't
+    yet been generated: `itertools.tee` is implemented in C (i.e. fast), whereas we
+    buffer (and lock) in Python instead.
+
+    However, the lock is not needed for elements that have already been buffered,
+    allowing for concurrent iteration once the generator has initially been consumed.
     """
+    def __init__(self, source_iter: Iterator[YieldType],
+                 buffer: list[YieldType] = None, lock: RLock = None) \
+            -> None:
+        # If no buffer/lock are provided, this is a parent iterator
+        self._source_iter = source_iter
+        self._buffer = buffer if buffer is not None else []
+        self._lock = lock if lock is not None else RLock()
+        self._next = 0
+
+    def __iter__(self) -> Iterator[YieldType]:
+        return self
+
+    def __next__(self) -> YieldType:
+        """
+        Safely retrieves the next buffered element if available, or generates it
+        from the source iterator if not.
+        """
+        # Retry concurrent element access until we can return a value
+        while True:
+            if self._next < len(self._buffer):
+                # If we have another buffered element, return it
+                result = self._buffer[self._next]
+                self._next += 1
+
+                return result
+
+            # Otherwise, we may need to generate a new element
+            with self._lock:
+                if self._next < len(self._buffer):
+                    # Another thread has already generated the next element; retry
+                    continue
+
+                try:
+                    # Advance the source iterator and buffer the new element
+                    result = next(self._source_iter)
+                    self._buffer.append(result)
+                    self._next += 1
+                    return result
+                except StopIteration:
+                    # The source iterator has been exhausted
+                    raise
+
+    def __copy__(self) -> 'SafeTee':
+        return SafeTee(self._source_iter, self._buffer, self._lock)
+
+    def tee(self) -> Iterator[YieldType]:
+        """
+        Creates a new iterator that shares the same buffer and lock.
+        """
+        return self.__copy__()
 
-    def __init__(self, func):
-        self.func = func
 
-    def __repr__(self):
-        """Return the function's docstring."""
-        return self.func.__doc__
+class memoized_generator(Generic[InstanceType, ParamsType, YieldType]):
+    """
+    Decorator for a cached instance generator method. Each call returns a thread-safe
+    version of `itertools.tee`: generating a not-yet-buffered element blocks, while
+    already-buffered elements can be iterated concurrently.
+ """ + + def __init__(self, meth: Callable[Concatenate[InstanceType, ParamsType], + Iterator[YieldType]]) -> None: + self._meth = meth + self._lock = RLock() # Lock for initial generator calls + update_wrapper(self, self._meth) - def __get__(self, obj, objtype=None): - if obj is None: - return self.func + def __get__(self, obj: InstanceType, cls: type[InstanceType] | None = None) \ + -> Callable[ParamsType, Iterator[YieldType]]: + """ + Binds the memoized method to an instance. + """ return partial(self, obj) - def __call__(self, *args, **kwargs): - if not isinstance(args, Hashable): - # Uncacheable, a list, for instance. - # Better to not cache than blow up. - return self.func(*args) - obj = args[0] + def _get_cache(self, obj: InstanceType) -> dict[Hashable, SafeTee[YieldType]]: + """ + Retrieves the generator cache for the given object instance, initializing + it if necessary. + """ + # Try-catch is theoretically faster on the happy path try: - cache = obj.__cache_gen + # Attempt to access the cache directly + return obj._memoized_generator__cache + + # If the cache doesn't exist, initialize it except AttributeError: - cache = obj.__cache_gen = {} - key = (self.func, args[1:], frozenset(kwargs.items())) - it = cache[key] if key in cache else self.func(*args, **kwargs) - cache[key], result = tee(it) - return result + with self._lock: + # Check again in case another thread initialized outside the lock + if not hasattr(obj, '_memoized_generator__cache'): + # Initialize the cache if it doesn't exist + obj._memoized_generator__cache = {} + + # Return the cache + return obj._memoized_generator__cache + + def __call__(self, obj: InstanceType, + *args: ParamsType.args, **kwargs: ParamsType.kwargs) \ + -> Iterator[YieldType]: + """ + Invokes the memoized generator, caching a SafeTee if it hasn't been created yet. 
+ """ + # Get the local cache for the object instance + cache = self._get_cache(obj) + key = (self._meth, args, frozenset(kwargs.items())) + try: + # Try to retrieve the cached value + res = cache[key] + except KeyError: + # If not cached, compute the value + source_iter = self._meth(obj, *args, **kwargs) + res = cache[key] = SafeTee(source_iter) + + return res.tee() # Describes the type of a subclass of CacheInstances -InstanceType = TypeVar('InstanceType', bound='CacheInstances', covariant=True) +CachedInstanceType = TypeVar('CachedInstanceType', + bound='CacheInstances', covariant=True) class CacheInstancesMeta(type): @@ -139,14 +221,14 @@ class CacheInstancesMeta(type): _cached_types: set[type['CacheInstances']] = set() - def __init__(cls: type[InstanceType], *args) -> None: # type: ignore + def __init__(cls: type[CachedInstanceType], *args) -> None: # type: ignore super().__init__(*args) # Register the cached type CacheInstancesMeta._cached_types.add(cls) - def __call__(cls: type[InstanceType], # type: ignore - *args, **kwargs) -> InstanceType: + def __call__(cls: type[CachedInstanceType], # type: ignore + *args, **kwargs) -> CachedInstanceType: if cls._instance_cache is None: maxsize = cls._instance_cache_size cls._instance_cache = lru_cache(maxsize=maxsize)(super().__call__) diff --git a/devito/types/basic.py b/devito/types/basic.py index 8c7e960fb2..f8acfa69c2 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -16,7 +16,7 @@ from devito.tools import (Pickable, as_tuple, dtype_to_ctype, frozendict, memoized_meth, sympy_mutex, CustomDtype) from devito.types.args import ArgProvider -from devito.types.caching import Cached, Uncached +from devito.types.caching import CacheManager, Cached, Uncached from devito.types.lazy import Evaluable from devito.types.utils import DimensionTuple @@ -559,24 +559,31 @@ def _cache_key(cls, *args, **kwargs): def __new__(cls, *args, **kwargs): assumptions, kwargs = cls._filter_assumptions(**kwargs) key = cls._cache_key(*args, **{**assumptions, **kwargs}) - obj = cls._cache_get(key) + # Initial cache lookup (not locked) + obj = cls._cache_get(key) if obj is not None: return obj - # Not in cache. Create a new Symbol via sympy.Symbol - args = list(args) - name = kwargs.pop('name', None) or args.pop(0) - newobj = cls.__xnew__(cls, name, **assumptions) + # Lock against the symbol cache and double-check the cache + with CacheManager.lock(): + obj = cls._cache_get(key) + if obj is not None: + return obj - # Initialization - newobj._dtype = cls.__dtype_setup__(**kwargs) - newobj.__init_finalize__(name, *args, **kwargs) + # Not in cache. 
+            args = list(args)
+            name = kwargs.pop('name', None) or args.pop(0)
+            newobj = cls.__xnew__(cls, name, **assumptions)
 
-        # Store new instance in symbol cache
-        Cached.__init__(newobj, key)
+            # Initialization
+            newobj._dtype = cls.__dtype_setup__(**kwargs)
+            newobj.__init_finalize__(name, *args, **kwargs)
 
-        return newobj
+            # Store new instance in symbol cache
+            Cached.__init__(newobj, key)
+
+            return newobj
 
     __hash__ = Cached.__hash__
diff --git a/devito/types/caching.py b/devito/types/caching.py
index 948ae09e88..37b0bd7400 100644
--- a/devito/types/caching.py
+++ b/devito/types/caching.py
@@ -1,4 +1,5 @@
 import gc
+from threading import RLock
 import weakref
 
 import sympy
@@ -10,6 +11,7 @@
 __all__ = ['Cached', 'Uncached', '_SymbolCache', 'CacheManager']
 
 _SymbolCache = {}
 """The symbol cache."""
+_cache_lock = RLock()
 
@@ -76,8 +78,10 @@ def _cache_get(cls, key):
             obj = obj_cached()
             if obj is None:
                 # Cleanup _SymbolCache (though practically unnecessary)
-                # does not fail if it's already gone
-                _SymbolCache.pop(key, None)
+                with _cache_lock:
+                    # Ensure another thread hasn't replaced the ref we're evicting
+                    if _SymbolCache.get(key) is obj_cached:
+                        _SymbolCache.pop(key, None)
                 return None
             else:
                 return obj
@@ -196,16 +200,21 @@ def clear(cls, force=True):
             # We won't call gc.collect() this time
             cls.ncalls_w_force_false += 1
 
-        for key in cache_copied:
-            obj = _SymbolCache.get(key)
-            if obj is None:
-                # deleted by another thread since we took the copy
-                continue
-            if obj() is None:
-                # (key could be removed in another thread since get() above)
-                _SymbolCache.pop(key, None)
+        for key, obj_cached in cache_copied.items():
+            if obj_cached() is None:
+                with _cache_lock:
+                    # Only evict if the cache still holds our (now dead) snapshot
+                    if _SymbolCache.get(key) is obj_cached:
+                        _SymbolCache.pop(key, None)
 
         # Maybe trigger garbage collection
         if force:
             del cache_copied
             gc.collect()
+
+    @staticmethod
+    def lock():
+        """
+        Returns the global symbol cache lock, used to make symbol construction atomic.
+        """
+        return _cache_lock
diff --git a/tests/test_tools.py b/tests/test_tools.py
index e53741086d..477fac6a4d 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -1,3 +1,6 @@
+import time
+from concurrent.futures import ThreadPoolExecutor
+from threading import RLock
 import numpy as np
 import pytest
 from sympy.abc import a, b, c, d, e
@@ -7,7 +10,7 @@
 from devito import Operator, Eq
 from devito.tools import (UnboundedMultiTuple, ctypes_to_cstr, toposort,
                           filter_ordered, transitive_closure, UnboundTuple,
-                          CacheInstances)
+                          CacheInstances, memoized_meth, memoized_generator)
 from devito.types.basic import Symbol
 
@@ -209,3 +212,156 @@ def __init__(self, value: int):
         # Cache should be cleared after Operator construction
         cache_size = Object._instance_cache.cache_info()[-1]
         assert cache_size == 0
+
+
+class TestMemoizedMethods:
+
+    def test_memoized_meth(self):
+        """
+        Tests basic functionality of memoized_meth.
+        """
+        class Object:
+            def __init__(self):
+                self.misses = 0
+
+            @memoized_meth
+            def compute(self, x):
+                self.misses += 1
+                return x * 2
+
+        obj = Object()
+        obj.compute(2)
+        obj.compute(4)
+        assert obj.compute(2) == 4
+        assert obj.compute(4) == 8
+        assert obj.misses == 2  # Only two unique calls
+
+    def test_unhashable_args(self):
+        """
+        Tests that memoized_meth raises an error for unhashable arguments.
+ """ + class Object: + def __init__(self): + self.misses = 0 + + @memoized_meth + def compute(self, x: list[int]): + self.misses += 1 + return sum(x) + + obj = Object() + with pytest.raises(TypeError): + obj.compute([1, 2, 3]) + + @pytest.mark.parametrize('num_threads', [5, 11, 17]) + def test_memoized_meth_concurrency(self, num_threads: int): + """ + Tests concurrent calls to a memoized method + """ + # Each thread should have its own cache; the calls should not block + class Object: + def __init__(self): + self.misses = 0 + self.lock = RLock() + + @memoized_meth + def compute(self, x): + # print ID of the running thread + with self.lock: + self.misses += 1 + + # Simulate some computation + time.sleep(0.2) + return x * 2 + + obj = Object() + def worker(x: int) -> int: + a = obj.compute(x) + b = obj.compute(x) + assert a == b + return a + + with ThreadPoolExecutor(max_workers=num_threads) as executor: + stime = time.perf_counter() + futures = [executor.submit(worker, i % 4) for i in range(num_threads)] + results = [f.result() for f in futures] + etime = time.perf_counter() + + assert len(set(results)) == 4 # Should have gotten four unique results + assert obj.misses == num_threads # Each thread should have missed once + + # Ensure that the total time is approximately 0.2 seconds (one miss per thread) + expected = 0.2 + assert abs(etime - stime - expected) < 0.1 * expected + + def test_memoized_generator(self): + """ + Tests basic functionality of memoized_generator + """ + class Object: + def __init__(self): + self.misses = 0 + + @memoized_generator + def compute(self, x): + self.misses += 1 + yield x * 2 + yield x * 3 + + obj = Object() + list(obj.compute(2)) + assert tuple(obj.compute(2)) == (4, 6) + assert obj.misses == 1 # Only one unique call + + @pytest.mark.parametrize('num_threads', [5, 11, 17]) + def test_memoized_generator_concurrency(self, num_threads: int): + """ + Tests concurrent calls to a memoized generator + """ + class Object: + def __init__(self): + self.misses = 0 + self.lock = RLock() + + @memoized_generator + def compute(self, x): + with self.lock: + self.misses += 1 + + time.sleep(0.25) + yield x * 2 + + time.sleep(0.25) + yield x * 3 + + # With memoized_generator, the initial construction should block but iteration + # should be concurrent and reuse the same iterator. 
+
+        obj = Object()
+
+        def worker(x: int) -> list[int]:
+            return list(obj.compute(x))
+
+        # If one thread consumes the generator, subsequent iteration shouldn't block.
+        # First we iterate concurrently; all but one thread will block waiting for
+        # the producing thread, so every worker takes ~0.5 seconds
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            stime = time.perf_counter()
+            futures = [executor.submit(worker, i % 4) for i in range(num_threads)]
+            results = [f.result() for f in futures]
+            etime = time.perf_counter()
+
+        expected = 0.5
+        assert abs(etime - stime - expected) < 0.1 * expected
+        assert set(tuple(r) for r in results) == {(0, 0), (2, 3), (4, 6), (6, 9)}
+        assert obj.misses == 4  # One miss per unique call
+
+        # Now iterating the same calls should use buffered generators from the cache
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            stime = time.perf_counter()
+            futures = [executor.submit(worker, i % 4) for i in range(num_threads)]
+            results = [f.result() for f in futures]
+            etime = time.perf_counter()
+
+        assert etime - stime < 0.1  # Should take epsilon time
+        assert set(tuple(r) for r in results) == {(0, 0), (2, 3), (4, 6), (6, 9)}
+        assert obj.misses == 4  # No new misses; all calls reused cached generators
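
Usage sketch (editorial, not part of the patch): the snippet below illustrates the per-thread caching semantics that the new `memoized_meth` provides, mirroring what the tests above assert. The class `ExpensiveModel` and its method are hypothetical; only the `devito.tools.memoized_meth` import reflects the diff.

from concurrent.futures import ThreadPoolExecutor

from devito.tools import memoized_meth


class ExpensiveModel:
    """Hypothetical stand-in for a class with a costly method."""

    def __init__(self):
        self.calls = 0

    @memoized_meth
    def shape(self, ndim):
        # Runs once per (thread, arguments) pair: each thread keeps its own
        # cache in a thread-local slot on the instance itself
        self.calls += 1
        return tuple(101 for _ in range(ndim))


model = ExpensiveModel()
assert model.shape(3) is model.shape(3)  # second call is a cache hit

with ThreadPoolExecutor(max_workers=1) as ex:
    ex.submit(model.shape, 3).result()   # re-evaluates: caches are per-thread

assert model.calls == 2

Because the cache lives on the instance rather than on the decorator, cached values are released together with the object, the same property the old `memoized_meth` had.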
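A companion sketch for `memoized_generator`, under the same assumptions (the `Stencil` class and its values are made up): the first consumer drives the wrapped generator and fills the shared `SafeTee` buffer, and later calls replay the buffer without re-running the body.

from devito.tools import memoized_generator


class Stencil:
    """Hypothetical class with a generator method worth caching."""

    def __init__(self):
        self.expansions = 0

    @memoized_generator
    def coefficients(self, order):
        self.expansions += 1  # the generator body runs lazily, on first iteration
        for k in range(order):
            yield k / order


s = Stencil()
first = list(s.coefficients(4))
second = list(s.coefficients(4))  # replayed from the SafeTee buffer

assert first == second == [0.0, 0.25, 0.5, 0.75]
assert s.expansions == 1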