Source code for tad_mctc.tools.caching

# This file is part of tad-mctc.
#
# SPDX-Identifier: Apache-2.0
# Copyright (C) 2024 Grimme Group
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tools: Caching
--------------

Decorators for memoization/caching.
"""

from __future__ import annotations

from functools import wraps

import torch

from ..typing import Any, CacheKey, Callable, TypeVar

__all__ = ["memoize", "memoize_all_instances"]

T = TypeVar("T")


[docs] def memoize(fcn: Callable[..., T]) -> Callable[..., T]: # pragma: no cover """ Memoization decorator that writes the cache to the object itself, hence not allowing the specification of `__slots__`. It works with and without function arguments. Note that `lru_cache` can produce memory leaks for a method. Parameters ---------- fcn : Callable[[Any], T] Function to memoize Returns ------- Callable[[Any], T] Memoized function. """ @wraps(fcn) def wrapper(self: Any, *args: Any, **kwargs: Any) -> T: # name mangling of dunder attributes! name = f"_{self.__class__.__name__}__memoization_cache" if hasattr(self, "__slots__"): # if __slots__ are defined, `__memoization_cache` must be in them if "__memoization_cache" not in self.__slots__: raise AttributeError( "Cannot use memoize with objects that specify " "`__slots__`, unless the `__memoization_cache` " "attribute is included in them." ) if not hasattr(self, name): # Create a cache dictionary as an instance attribute setattr(self, name, {}) cache = getattr(self, name, {}) # Create a unique key for the cache dictionary key = (id(self), fcn.__name__, args, frozenset(kwargs.items())) # If the result is already in the cache, return it if key in cache: return cache[key] # If key is not found, compute the result result = fcn(self, *args, **kwargs) cache[key] = result return result def clear(self: Any) -> None: name = f"_{self.__class__.__name__}__memoization_cache" if hasattr(self, name): setattr(self, name, {}) def get(self: Any) -> dict[str, Any]: name = f"_{self.__class__.__name__}__memoization_cache" if not hasattr(self, name): return {} return getattr(self, name) setattr(wrapper, "clear", clear) setattr(wrapper, "clear_cache", clear) setattr(wrapper, "get_cache", get) return wrapper
[docs] def memoize_all_instances(fcn: Callable[..., T]) -> Callable[..., T]: """ Memoization with shared cache among all instances of the decorated function. This decorator allows specification of `__slots__`. It works with and without function arguments. Note that `lru_cache` can produce memory leaks for a method. Parameters ---------- fcn : Callable[[Any], T] Function to memoize Returns ------- Callable[[Any], T] Memoized function. """ # creating the cache outside the wrapper shares it across instances cache: dict[CacheKey, T] = {} @wraps(fcn) def wrapper(self: Any, *args: Any, **kwargs: Any) -> T: # create unique key for all instances in cache dictionary key = (id(self), fcn.__name__, args, frozenset(kwargs.items())) # if the result is already in the cache, return it if key in cache: return cache[key] # if key is not found, compute the result result = fcn(self, *args, **kwargs) cache[key] = result return result def clear(*_: Any) -> None: cache.clear() def get(*_: Any) -> dict[CacheKey, T]: return cache setattr(wrapper, "clear", clear) setattr(wrapper, "clear_cache", clear) setattr(wrapper, "get_cache", get) return wrapper
def memoize_with_deps( *dependency_getters: Callable[..., Any] ) -> Callable[..., Any]: # pragma: no cover """ Memoization with multiple dependency-based cache invalidation. This decorator allows specification of `__slots__`. It works with and without function arguments. Warning ------- This is an experimental feature, which can cause memory leaks! """ def decorator(fcn: Callable[..., T]) -> Callable[..., T]: # creating the cache outside the wrapper shares it across instances cache: dict[CacheKey, T] = {} dependency_cache: dict[CacheKey, tuple[Any, ...]] = {} @wraps(fcn) def wrapper(self: Any, *args: Any, **kwargs: Any) -> T: # create unique key for all instances in cache dictionary key = (id(self), fcn.__name__, args, frozenset(kwargs.items())) # get current deps current_deps = tuple(getter(self) for getter in dependency_getters) cached_deps = dependency_cache.get(key) # Check if the cache has been invalidated cache_invalidated = False if cached_deps is None or len(cached_deps) != len(current_deps): cache_invalidated = True else: for curr, cached in zip(current_deps, cached_deps): if not torch.equal(curr, cached): cache_invalidated = True break if not cache_invalidated and key in cache: return cache[key] # If result is not in cache or deps have changed, compute result result = fcn(self, *args, **kwargs) cache[key] = result dependency_cache[key] = current_deps return result def clear() -> None: cache.clear() dependency_cache.clear() def get() -> dict[CacheKey, T]: return cache def get_dep() -> dict[CacheKey, tuple[Any, ...]]: return dependency_cache setattr(wrapper, "clear", clear) setattr(wrapper, "clear_cache", clear) setattr(wrapper, "get_cache", get) setattr(wrapper, "get_dep_cache", get_dep) return wrapper return decorator