-
Notifications
You must be signed in to change notification settings - Fork 132
Add cache_method decorator
#895
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
60bdd54
b00ed57
a258f72
06f50d1
51f6c1f
5592d2b
a397fcf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,124 @@ | ||
| # This code is part of Qiskit. | ||
| # | ||
| # (C) Copyright IBM 2022. | ||
| # | ||
| # This code is licensed under the Apache License, Version 2.0. You may | ||
| # obtain a copy of this license in the LICENSE.txt file in the root directory | ||
| # of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
| # | ||
| # Any modifications or derivative works of this code must retain this | ||
| # copyright notice, and modified files need to carry a notice indicating | ||
| # that they have been altered from the originals. | ||
| """ | ||
| Method decorator for caching regular methods in class instances. | ||
| """ | ||
|
|
||
| from typing import Union, Dict, Callable | ||
| import functools | ||
|
|
||
|
|
||
| def cache_method(cache: Union[Dict, str] = "_cache", cache_args: bool = True) -> Callable: | ||
| """Decorator for caching regular methods in classes. | ||
|
|
||
| .. note:: | ||
|
|
||
| When specifying a cache an existing dictionary value will be | ||
| used as is. A string value will be used to check for an existing | ||
| dict under that attribute name in the class instance. | ||
| If the attribute is not present a new cache dict will be created | ||
| and stored in that class instance. | ||
|
|
||
| Args: | ||
| cache: A dictionary or attribute name string to use as cache. | ||
| cache_args: If True include method arg and kwarg values when | ||
| matching cached values. These values must be hashable. | ||
|
|
||
| Returns: | ||
| The decorator for caching methods. | ||
| """ | ||
| cache_fn = _cache_function(cache) | ||
| cache_key_fn = _cache_key_function(cache_args) | ||
|
|
||
| def cache_method_decorator(method: Callable) -> Callable: | ||
| """Decorator for caching method. | ||
|
|
||
| Args: | ||
| method: A method to cache. | ||
|
|
||
| Returns: | ||
| The wrapped cached method. | ||
| """ | ||
|
|
||
| @functools.wraps(method) | ||
| def _cached_method(self, *args, **kwargs): | ||
| meth_cache = cache_fn(self, method) | ||
| key = cache_key_fn(*args, **kwargs) | ||
| if key in meth_cache: | ||
| return meth_cache[key] | ||
| result = method(self, *args, **kwargs) | ||
| meth_cache[key] = result | ||
| return result | ||
|
|
||
| return _cached_method | ||
|
|
||
| return cache_method_decorator | ||
|
|
||
|
|
||
| def _cache_key_function(cache_args: bool) -> Callable: | ||
| """Return function for generating cache keys. | ||
|
|
||
| Args: | ||
| cache_args: If True include method arg and kwarg values when | ||
| caching the method. If False all calls to the instances | ||
| method will return the same cached value regardless of | ||
| any arg or kwarg values. | ||
|
|
||
| Returns: | ||
| The functions for generating cache keys. | ||
| """ | ||
| if not cache_args: | ||
|
|
||
| def _cache_key(*args, **kwargs): | ||
| # pylint: disable = unused-argument | ||
| return tuple() | ||
|
|
||
| else: | ||
|
|
||
| def _cache_key(*args, **kwargs): | ||
| return args + tuple(list(kwargs.items())) | ||
|
|
||
| return _cache_key | ||
|
|
||
|
|
||
| def _cache_function(cache: Union[Dict, str]) -> Callable: | ||
| """Return function for initializing and accessing cache dict. | ||
|
|
||
| Args: | ||
| cache: The dictionary or cache attribute name to use. If a dict it | ||
| will be used directly, if a str a cache dict will be created | ||
| under that attribute name if one is not already present. | ||
|
|
||
| Returns: | ||
| The function for accessing the cache dict. | ||
| """ | ||
| if isinstance(cache, str): | ||
|
|
||
| def _cache_fn(instance, method): | ||
| if not hasattr(instance, cache): | ||
| setattr(instance, cache, {}) | ||
nkanazawa1989 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| instance_cache = getattr(instance, cache) | ||
| name = method.__name__ | ||
| if name not in instance_cache: | ||
| instance_cache[name] = {} | ||
| return instance_cache[name] | ||
|
|
||
| else: | ||
|
|
||
| def _cache_fn(instance, method): | ||
| # pylint: disable = unused-argument | ||
| name = method.__name__ | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @nkanazawa1989 I wonder if this should be
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think qualname should be useful if we want to support class level cache in future. Also you can validate that the |
||
| if name not in cache: | ||
| cache[name] = {} | ||
| return cache[name] | ||
|
|
||
| return _cache_fn | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,5 @@ | ||
| --- | ||
| developer: | ||
| - | | ||
| Adds a ``cache_method`` decorator for caching methods in classes. This | ||
| can be imported from ``qiskit_experiments.framework.cache_method``. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| # This code is part of Qiskit. | ||
| # | ||
| # (C) Copyright IBM 2022. | ||
| # | ||
| # This code is licensed under the Apache License, Version 2.0. You may | ||
| # obtain a copy of this license in the LICENSE.txt file in the root directory | ||
| # of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
| # | ||
| # Any modifications or derivative works of this code must retain this | ||
| # copyright notice, and modified files need to carry a notice indicating | ||
| # that they have been altered from the originals. | ||
|
|
||
| """Tests cache_method decorator.""" | ||
|
|
||
| from test.base import QiskitExperimentsTestCase | ||
| from qiskit_experiments.framework.cache_method import cache_method | ||
|
|
||
|
|
||
| class TestCacheMethod(QiskitExperimentsTestCase): | ||
| """Test for cache_method decorator""" | ||
|
|
||
| def test_cache_args(self): | ||
| """Test cache_args=True""" | ||
|
|
||
| class CachedClass: | ||
| """Class with cached method""" | ||
|
|
||
| def __init__(self): | ||
| self.method_calls = 0 | ||
|
|
||
| @cache_method(cache_args=True) | ||
| def method(self, *args, **kwargs): | ||
| """Test method for caching""" | ||
| self.method_calls += 1 | ||
| return args, kwargs | ||
|
|
||
| obj = CachedClass() | ||
| size = 10 | ||
| cached_vals = [obj.method(i, i) for i in range(size)] | ||
| for i, val in enumerate(cached_vals): | ||
| self.assertEqual(obj.method(i, i), val, msg="method didn't return cached value") | ||
| self.assertEqual(obj.method_calls, size, msg="Cached method was not evaluated once per arg") | ||
|
|
||
| def test_cache_args_kwargs(self): | ||
| """Test cache_args=True with args and kwargs""" | ||
|
|
||
| class CachedClass: | ||
| """Class with cached method""" | ||
|
|
||
| def __init__(self): | ||
| self.method_calls = 0 | ||
|
|
||
| @cache_method(cache_args=True) | ||
| def method(self, *args, **kwargs): | ||
| """Test method for caching""" | ||
| self.method_calls += 1 | ||
| return args, kwargs | ||
|
|
||
| obj = CachedClass() | ||
| args = (1, 2, 3) | ||
| names = ["a", "b", "c", "d"] | ||
| cached_vals = [obj.method(*args, name=name) for name in names] | ||
| for name, val in zip(names, cached_vals): | ||
| self.assertEqual( | ||
| obj.method(*args, name=name), val, msg="method didn't return cached value" | ||
| ) | ||
| self.assertEqual( | ||
| obj.method_calls, len(names), msg="Cached method was not evaluated once per arg" | ||
| ) | ||
|
|
||
| def test_cache_args_false(self): | ||
| """Test cache_args=False""" | ||
|
|
||
| class CachedClass: | ||
| """Class with cached method""" | ||
|
|
||
| def __init__(self): | ||
| self.method_calls = 0 | ||
|
|
||
| @cache_method(cache_args=False) | ||
| def method(self, *args, **kwargs): | ||
| """Test method for caching""" | ||
| self.method_calls += 1 | ||
| return args, kwargs | ||
|
|
||
| obj = CachedClass() | ||
| ret = obj.method(1999) | ||
| for i in range(10): | ||
| self.assertEqual(obj.method(i), ret, msg="method didn't return cached value") | ||
| self.assertEqual(obj.method_calls, 1, msg="Cached method was not evaluated once") | ||
|
|
||
| def test_non_hashable_raises(self): | ||
| """Test non hashable args raise""" | ||
|
|
||
| class CachedClass: | ||
| """Class with cached method""" | ||
|
|
||
| @cache_method() | ||
| def method(self, *args, **kwargs): | ||
| """Test method for caching""" | ||
| return args, kwargs | ||
|
|
||
| obj = CachedClass() | ||
| self.assertRaises(TypeError, obj.method, [1, 2, 3]) | ||
nkanazawa1989 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self.assertRaises(TypeError, obj.method, kwarg=[1, 2, 3]) | ||
|
|
||
| def test_cache_name(self): | ||
| """Test decorator with a custom cache name""" | ||
|
|
||
| class CachedClass: | ||
| """Class with cached method""" | ||
|
|
||
| @cache_method(cache="memory") | ||
| def method(self, *args, **kwargs): | ||
| """Test method for caching""" | ||
| return args, kwargs | ||
|
|
||
| obj = CachedClass() | ||
| obj.method(1, 2, 3) | ||
| self.assertTrue(hasattr(obj, "memory")) | ||
| self.assertIn("method", getattr(obj, "memory", {})) | ||
|
|
||
| def test_cache_dict(self): | ||
| """Test decorate with custom cache value""" | ||
|
|
||
| external_cache = {} | ||
|
|
||
| class CachedClass: | ||
| """Class with cached method""" | ||
|
|
||
| @cache_method(cache=external_cache) | ||
| def method(self, *args, **kwargs): | ||
| """Test method for caching""" | ||
| return args, kwargs | ||
|
|
||
| obj = CachedClass() | ||
| obj.method(1, 2, 3) | ||
| self.assertIn("method", external_cache) | ||
Uh oh!
There was an error while loading. Please reload this page.