Skip to content

Commit ade9ae4

Browse files
committed
Merge branch 'feature/ignore_args_by_name' into develop
2 parents ec89562 + 94f29e8 commit ade9ae4

File tree

2 files changed

+100
-21
lines changed

2 files changed

+100
-21
lines changed

src/redis_func_cache/cache.py

Lines changed: 70 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,19 @@ async def aput(
441441
ext_args = ext_args or ()
442442
await script(keys=key_pair, args=chain((maxsize, ttl, hash_, value, encoded_options), ext_args))
443443

444-
def _before_get(self, user_function, user_args, user_kwds):
445-
keys = self.policy.calc_keys(user_function, user_args, user_kwds)
446-
hash_value = self.policy.calc_hash(user_function, user_args, user_kwds)
447-
ext_args = self.policy.calc_ext_args(user_function, user_args, user_kwds) or ()
444+
def _before_get(
445+
self,
446+
user_function,
447+
user_args,
448+
user_kwds,
449+
exclude_positional_args: Optional[Sequence[int]] = None,
450+
exclude_keyword_args: Optional[Sequence[str]] = None,
451+
):
452+
args = [x for i, x in enumerate(user_args) if i not in (exclude_positional_args or [])]
453+
kwds = {k: v for k, v in user_kwds.items() if k not in (exclude_keyword_args or {})}
454+
keys = self.policy.calc_keys(user_function, args, kwds)
455+
hash_value = self.policy.calc_hash(user_function, args, kwds)
456+
ext_args = self.policy.calc_ext_args(user_function, args, kwds) or ()
448457
return keys, hash_value, ext_args
449458

450459
def exec(
@@ -454,6 +463,8 @@ def exec(
454463
user_kwds: Mapping[str, Any],
455464
serialize_func: Optional[SerializerT] = None,
456465
deserialize_func: Optional[DeserializerT] = None,
466+
exclude_positional_args: Optional[Sequence[int]] = None,
467+
exclude_keyword_args: Optional[Sequence[str]] = None,
457468
**options,
458469
):
459470
"""Execute the given user function with the provided arguments.
@@ -476,7 +487,9 @@ def exec(
476487
script_0, script_1 = self.policy.lua_scripts
477488
if not (isinstance(script_0, redis.commands.core.Script) and isinstance(script_1, redis.commands.core.Script)):
478489
raise RuntimeError("Can not eval redis lua script in asynchronous mode on a synchronous redis client")
479-
keys, hash_value, ext_args = self._before_get(user_function, user_args, user_kwds)
490+
keys, hash_value, ext_args = self._before_get(
491+
user_function, user_args, user_kwds, exclude_positional_args, exclude_keyword_args
492+
)
480493
cached_return_value = self.get(script_0, keys, hash_value, self.ttl, options, ext_args)
481494
if cached_return_value is not None:
482495
return self.deserialize(cached_return_value, deserialize_func)
@@ -492,6 +505,8 @@ async def aexec(
492505
user_kwds: Mapping[str, Any],
493506
serialize_func: Optional[SerializerT] = None,
494507
deserialize_func: Optional[DeserializerT] = None,
508+
exclude_positional_args: Optional[Sequence[int]] = None,
509+
exclude_keyword_args: Optional[Sequence[str]] = None,
495510
**options,
496511
):
497512
"""Asynchronous version of :meth:`.exec`"""
@@ -501,7 +516,9 @@ async def aexec(
501516
and isinstance(script_1, redis.commands.core.AsyncScript)
502517
):
503518
raise RuntimeError("Can not eval redis lua script in synchronous mode on an asynchronous redis client")
504-
keys, hash_value, ext_args = self._before_get(user_function, user_args, user_kwds)
519+
keys, hash_value, ext_args = self._before_get(
520+
user_function, user_args, user_kwds, exclude_positional_args, exclude_keyword_args
521+
)
505522
cached = await self.aget(script_0, keys, hash_value, self.ttl, options, ext_args)
506523
if cached is not None:
507524
return self.deserialize(cached, deserialize_func)
@@ -515,7 +532,9 @@ def decorate(
515532
user_function: Optional[CallableTV] = None,
516533
/,
517534
serializer: Optional[SerializerSetterValueT] = None,
518-
**keywords,
535+
exclude_positional_args: Optional[Sequence[int]] = None,
536+
exclude_keyword_args: Optional[Sequence[str]] = None,
537+
**options,
519538
) -> CallableTV:
520539
"""Decorate the given function with caching.
521540
@@ -524,9 +543,21 @@ def decorate(
524543
525544
serializer: serialize/deserialize name or function pair for return value of what decorated.
526545
527-
If defined, it overrides the first element of :attr:`serializer`.
546+
It accepts either:
547+
- A string key mapping to predefined serializers (like "yaml" or "json")
548+
- A tuple of (`serialize_func`, `deserialize_func`) functions
549+
550+
If assigned, it overwrite the :attr:`serializer` property of the cache instance on the decorated function.
528551
529-
**keywords: Additional options passed to :meth:`exec`, they will encoded to json, then pass to redis lua script.
552+
exclude_positional_args: A list of positional argument indices to exclude from cache key generation.
553+
554+
These arguments will be filtered out before cache operations.
555+
556+
exclude_keyword_args: A list of keyword argument names to exclude from cache key generation.
557+
558+
These parameters will be filtered out before cache operations.
559+
560+
**options: Additional options passed to :meth:`exec`, they will encoded to json, then pass to redis lua script.
530561
531562
This method is equivalent to :attr:`__call__`.
532563
@@ -585,25 +616,43 @@ def my_func(a, b):
585616
elif serializer is not None:
586617
serialize_func, deserialize_func = serializer
587618

588-
def decorator(f: CallableTV):
589-
@wraps(f)
590-
def wrapper(*args, **kwargs):
591-
return self.exec(f, args, kwargs, serialize_func, deserialize_func, **keywords)
592-
593-
@wraps(f)
594-
async def awrapper(*args, **kwargs):
595-
return await self.aexec(f, args, kwargs, serialize_func, deserialize_func, **keywords)
596-
597-
if not callable(f):
619+
def decorator(user_func: CallableTV):
620+
@wraps(user_func)
621+
def wrapper(*user_args, **user_kwargs):
622+
return self.exec(
623+
user_func,
624+
user_args,
625+
user_kwargs,
626+
serialize_func,
627+
deserialize_func,
628+
exclude_positional_args,
629+
exclude_keyword_args,
630+
**options,
631+
)
632+
633+
@wraps(user_func)
634+
async def awrapper(*user_args, **user_kwargs):
635+
return await self.aexec(
636+
user_func,
637+
user_args,
638+
user_kwargs,
639+
serialize_func,
640+
deserialize_func,
641+
exclude_positional_args,
642+
exclude_keyword_args,
643+
**options,
644+
)
645+
646+
if not callable(user_func):
598647
raise TypeError("Can not decorate a non-callable object.")
599648
if self.asynchronous:
600-
if not iscoroutinefunction(f):
649+
if not iscoroutinefunction(user_func):
601650
raise TypeError(
602651
"The decorated function or method must be a coroutine when using an asynchronous redis client."
603652
)
604653
return cast(CallableTV, awrapper)
605654
else:
606-
if iscoroutinefunction(f):
655+
if iscoroutinefunction(user_func):
607656
raise TypeError(
608657
"The decorated function or method cannot be a coroutine when using a asynchronous redis client."
609658
)

tests/test_basic.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,3 +356,33 @@ def test_lambda(self):
356356
for _ in range(cache.maxsize * 2 + 1):
357357
v = uuid4().hex
358358
self.assertEqual(v, f(v))
359+
360+
361+
class ExcludeArgsTestCase(TestCase):
362+
def setUp(self):
363+
for cache in CACHES.values():
364+
cache.policy.purge()
365+
366+
def test_exclude_positional_args(self):
367+
def user_func(func, value):
368+
return func(value)
369+
370+
unpickable_func = lambda x: x # noqa: E731
371+
372+
for cache in CACHES.values():
373+
f = cache(user_func, exclude_positional_args=[0])
374+
for _ in range(cache.maxsize * 2 + 1):
375+
v = uuid4().hex
376+
self.assertEqual(f(unpickable_func, v), v)
377+
378+
def test_exclude_keyword_args(self):
379+
def user_func(func, value):
380+
return func(value)
381+
382+
unpickable_func = lambda x: x # noqa: E731
383+
384+
for cache in CACHES.values():
385+
f = cache(user_func, exclude_keyword_args=["func"])
386+
for _ in range(cache.maxsize * 2 + 1):
387+
v = uuid4().hex
388+
self.assertEqual(f(func=unpickable_func, value=v), v)

0 commit comments

Comments
 (0)