11import asyncio
22import inspect
3+ from collections import defaultdict
34from copy import copy
45from logging import getLogger
5- from typing import TYPE_CHECKING , Any , Dict , Generator , List , Optional
6+ from typing import TYPE_CHECKING , Any , DefaultDict , Dict , Generator , List , Optional
67
78from taskiq_dependencies .utils import ParamInfo
89
@@ -49,6 +50,8 @@ def traverse_deps( # noqa: C901, WPS210
4950 # to separate dependencies that use cache,
5051 # from dependencies that aren't.
5152 cache = copy (self .initial_cache )
53+ # Cache for all dependencies with kwargs.
54+ kwargs_cache : "DefaultDict[Any, List[Any]]" = defaultdict (list )
5255 # We iterate over topologicaly sorted list of dependencies.
5356 for index , dep in enumerate (self .graph .ordered_deps ):
5457 # If this dependency doesn't use cache,
@@ -62,6 +65,19 @@ def traverse_deps( # noqa: C901, WPS210
6265 # If dependency is already calculated.
6366 if dep .dependency in cache :
6467 continue
68+ # For dependencies with kwargs we check kwarged cache.
69+ elif dep .kwargs and dep .dependency in kwargs_cache :
70+ cache_hit = False
71+ # We have to iterate over all cached dependencies with
72+ # kwargs, because users may pass unhashable objects as kwargs.
73+ # That's why we cannot use them as dict keys.
74+ for cached_kwargs , _ in kwargs_cache [dep .dependency ]:
75+ if cached_kwargs == dep .kwargs :
76+ cache_hit = True
77+ break
78+ if cache_hit :
79+ continue
80+
6581 kwargs = {}
6682 # Now we get list of dependencies for current top-level dependency
6783 # and iterate over it.
@@ -78,7 +94,13 @@ def traverse_deps( # noqa: C901, WPS210
7894 if subdep .use_cache :
7995 # If this dependency can be calculated, using cache,
8096 # we try to get it from cache.
81- kwargs [subdep .param_name ] = cache [subdep .dependency ]
97+ if subdep .kwargs and subdep .dependency in kwargs_cache :
98+ for cached_kwargs , kw_cache in kwargs_cache [subdep .dependency ]:
99+ if cached_kwargs == subdep .kwargs :
100+ kwargs [subdep .param_name ] = kw_cache
101+ break
102+ else :
103+ kwargs [subdep .param_name ] = cache [subdep .dependency ]
82104 else :
83105 # If this dependency doesn't use cache,
84106 # we resolve it's dependencies and
@@ -101,9 +123,13 @@ def traverse_deps( # noqa: C901, WPS210
101123 # because we calculate them when needed.
102124 and dep .dependency != ParamInfo
103125 ):
104- user_kwargs = dep .kwargs
126+ user_kwargs = copy ( dep .kwargs )
105127 user_kwargs .update (kwargs )
106- cache [dep .dependency ] = yield dep .dependency (** user_kwargs )
128+ resolved = yield dep .dependency (** user_kwargs )
129+ if dep .kwargs :
130+ kwargs_cache [dep .dependency ].append ((dep .kwargs , resolved ))
131+ else :
132+ cache [dep .dependency ] = resolved
107133 return kwargs
108134
109135
0 commit comments