diff --git a/qiskit_experiments/framework/cache_method.py b/qiskit_experiments/framework/cache_method.py new file mode 100644 index 0000000000..e2dc58650f --- /dev/null +++ b/qiskit_experiments/framework/cache_method.py @@ -0,0 +1,124 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. +""" +Method decorator for caching regular methods in class instances. +""" + +from typing import Union, Dict, Callable +import functools + + +def cache_method(cache: Union[Dict, str] = "_cache", cache_args: bool = True) -> Callable: + """Decorator for caching regular methods in classes. + + .. note:: + + When specifying a cache an existing dictionary value will be + used as is. A string value will be used to check for an existing + dict under that attribute name in the class instance. + If the attribute is not present a new cache dict will be created + and stored in that class instance. + + Args: + cache: A dictionary or attribute name string to use as cache. + cache_args: If True include method arg and kwarg values when + matching cached values. These values must be hashable. + + Returns: + The decorator for caching methods. + """ + cache_fn = _cache_function(cache) + cache_key_fn = _cache_key_function(cache_args) + + def cache_method_decorator(method: Callable) -> Callable: + """Decorator for caching method. + + Args: + method: A method to cache. + + Returns: + The wrapped cached method. + """ + + @functools.wraps(method) + def _cached_method(self, *args, **kwargs): + meth_cache = cache_fn(self, method) + key = cache_key_fn(*args, **kwargs) + if key in meth_cache: + return meth_cache[key] + result = method(self, *args, **kwargs) + meth_cache[key] = result + return result + + return _cached_method + + return cache_method_decorator + + +def _cache_key_function(cache_args: bool) -> Callable: + """Return function for generating cache keys. + + Args: + cache_args: If True include method arg and kwarg values when + caching the method. If False all calls to the instances + method will return the same cached value regardless of + any arg or kwarg values. + + Returns: + The functions for generating cache keys. + """ + if not cache_args: + + def _cache_key(*args, **kwargs): + # pylint: disable = unused-argument + return tuple() + + else: + + def _cache_key(*args, **kwargs): + return args + tuple(list(kwargs.items())) + + return _cache_key + + +def _cache_function(cache: Union[Dict, str]) -> Callable: + """Return function for initializing and accessing cache dict. + + Args: + cache: The dictionary or cache attribute name to use. If a dict it + will be used directly, if a str a cache dict will be created + under that attribute name if one is not already present. + + Returns: + The function for accessing the cache dict. + """ + if isinstance(cache, str): + + def _cache_fn(instance, method): + if not hasattr(instance, cache): + setattr(instance, cache, {}) + instance_cache = getattr(instance, cache) + name = method.__name__ + if name not in instance_cache: + instance_cache[name] = {} + return instance_cache[name] + + else: + + def _cache_fn(instance, method): + # pylint: disable = unused-argument + name = method.__name__ + if name not in cache: + cache[name] = {} + return cache[name] + + return _cache_fn diff --git a/qiskit_experiments/library/tomography/basis/local_basis.py b/qiskit_experiments/library/tomography/basis/local_basis.py index 7c39e46c5c..9356182441 100644 --- a/qiskit_experiments/library/tomography/basis/local_basis.py +++ b/qiskit_experiments/library/tomography/basis/local_basis.py @@ -12,13 +12,13 @@ """ Circuit basis for tomography preparation and measurement circuits """ -import functools from typing import Sequence, Optional, Tuple, Union, List, Dict import numpy as np from qiskit.circuit import QuantumCircuit, Instruction from qiskit.quantum_info import DensityMatrix, Statevector, Operator, SuperOp from qiskit.quantum_info.operators.channel.quantum_channel import QuantumChannel from qiskit.exceptions import QiskitError +from qiskit_experiments.framework.cache_method import cache_method from .base_basis import PreparationBasis, MeasurementBasis # Typing object for POVM args of measurement basis @@ -351,7 +351,7 @@ def matrix(self, index: Sequence[int], outcome: int, qubits: Optional[Sequence[i # a qubit not in the specified kwargs. raise ValueError(f"Invalid qubits for basis {self.name}") from ex - @functools.lru_cache(None) + @cache_method() def _outcome_indices(self, outcome: int, qubits: Tuple[int, ...]) -> Tuple[int, ...]: """Convert an outcome integer to a tuple of single-qubit outcomes""" num_outcomes = self._qubit_num_outcomes.get(qubits[0], self._default_num_outcomes) diff --git a/releasenotes/notes/cache-method-3824562833460741.yaml b/releasenotes/notes/cache-method-3824562833460741.yaml new file mode 100644 index 0000000000..e4b99fb7f6 --- /dev/null +++ b/releasenotes/notes/cache-method-3824562833460741.yaml @@ -0,0 +1,5 @@ +--- +developer: + - | + Adds a ``cache_method`` decorator for caching methods in classes. This + can be imported from ``qiskit_experiments.framework.cache_method``. diff --git a/test/framework/test_cache_method.py b/test/framework/test_cache_method.py new file mode 100644 index 0000000000..42e44832bf --- /dev/null +++ b/test/framework/test_cache_method.py @@ -0,0 +1,138 @@ +# This code is part of Qiskit. +# +# (C) Copyright IBM 2022. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +"""Tests cache_method decorator.""" + +from test.base import QiskitExperimentsTestCase +from qiskit_experiments.framework.cache_method import cache_method + + +class TestCacheMethod(QiskitExperimentsTestCase): + """Test for cache_method decorator""" + + def test_cache_args(self): + """Test cache_args=True""" + + class CachedClass: + """Class with cached method""" + + def __init__(self): + self.method_calls = 0 + + @cache_method(cache_args=True) + def method(self, *args, **kwargs): + """Test method for caching""" + self.method_calls += 1 + return args, kwargs + + obj = CachedClass() + size = 10 + cached_vals = [obj.method(i, i) for i in range(size)] + for i, val in enumerate(cached_vals): + self.assertEqual(obj.method(i, i), val, msg="method didn't return cached value") + self.assertEqual(obj.method_calls, size, msg="Cached method was not evaluated once per arg") + + def test_cache_args_kwargs(self): + """Test cache_args=True with args and kwargs""" + + class CachedClass: + """Class with cached method""" + + def __init__(self): + self.method_calls = 0 + + @cache_method(cache_args=True) + def method(self, *args, **kwargs): + """Test method for caching""" + self.method_calls += 1 + return args, kwargs + + obj = CachedClass() + args = (1, 2, 3) + names = ["a", "b", "c", "d"] + cached_vals = [obj.method(*args, name=name) for name in names] + for name, val in zip(names, cached_vals): + self.assertEqual( + obj.method(*args, name=name), val, msg="method didn't return cached value" + ) + self.assertEqual( + obj.method_calls, len(names), msg="Cached method was not evaluated once per arg" + ) + + def test_cache_args_false(self): + """Test cache_args=False""" + + class CachedClass: + """Class with cached method""" + + def __init__(self): + self.method_calls = 0 + + @cache_method(cache_args=False) + def method(self, *args, **kwargs): + """Test method for caching""" + self.method_calls += 1 + return args, kwargs + + obj = CachedClass() + ret = obj.method(1999) + for i in range(10): + self.assertEqual(obj.method(i), ret, msg="method didn't return cached value") + self.assertEqual(obj.method_calls, 1, msg="Cached method was not evaluated once") + + def test_non_hashable_raises(self): + """Test non hashable args raise""" + + class CachedClass: + """Class with cached method""" + + @cache_method() + def method(self, *args, **kwargs): + """Test method for caching""" + return args, kwargs + + obj = CachedClass() + self.assertRaises(TypeError, obj.method, [1, 2, 3]) + self.assertRaises(TypeError, obj.method, kwarg=[1, 2, 3]) + + def test_cache_name(self): + """Test decorator with a custom cache name""" + + class CachedClass: + """Class with cached method""" + + @cache_method(cache="memory") + def method(self, *args, **kwargs): + """Test method for caching""" + return args, kwargs + + obj = CachedClass() + obj.method(1, 2, 3) + self.assertTrue(hasattr(obj, "memory")) + self.assertIn("method", getattr(obj, "memory", {})) + + def test_cache_dict(self): + """Test decorate with custom cache value""" + + external_cache = {} + + class CachedClass: + """Class with cached method""" + + @cache_method(cache=external_cache) + def method(self, *args, **kwargs): + """Test method for caching""" + return args, kwargs + + obj = CachedClass() + obj.method(1, 2, 3) + self.assertIn("method", external_cache)