15
15
from asyncio import create_subprocess_exec , gather
16
16
from contextlib import asynccontextmanager , contextmanager
17
17
from dataclasses import dataclass
18
- from functools import lru_cache , partial
18
+ from functools import partial , wraps
19
19
from io import StringIO
20
20
from pathlib import Path
21
- from typing import Any , AsyncGenerator , Dict , List , Mapping , Optional , Tuple , cast
21
+ from typing import (
22
+ Any ,
23
+ AsyncGenerator ,
24
+ Callable ,
25
+ Coroutine ,
26
+ Dict ,
27
+ List ,
28
+ Mapping ,
29
+ Optional ,
30
+ Tuple ,
31
+ TypeVar ,
32
+ cast ,
33
+ )
22
34
23
35
from ruamel .yaml import YAML
24
36
from ruamel .yaml .error import YAMLError , YAMLWarning
44
56
logger = logging .getLogger (__name__ )
45
57
logger .addHandler (logging .NullHandler ())
46
58
59
+
60
+ # -- Utils -------------------------------------------------------------------
61
+
62
+ T = TypeVar ("T" )
63
+
64
+
65
+ class async_cache :
66
+ def __init__ (self , cache_dict = None ):
67
+ self ._dict = cache_dict or {}
68
+
69
+ def __call__ (
70
+ self , func : Callable [..., Coroutine [None , None , T ]]
71
+ ) -> Callable [..., Coroutine [None , None , T ]]:
72
+ @wraps (func )
73
+ async def get (* args , ** kwargs ):
74
+ key = (args , tuple (map (tuple , kwargs .items ())))
75
+ try :
76
+ return self ._dict [key ]
77
+ except KeyError :
78
+ value = await func (* args , ** kwargs )
79
+ self ._dict [key ] = value
80
+ return value
81
+
82
+ return get
83
+
84
+
47
85
# -- Git ---------------------------------------------------------------------
48
86
49
87
# The git commands in this section is partially sourced and modified from
@@ -198,7 +236,7 @@ async def tmp_repo(repo: str) -> AsyncGenerator[Path, Any]:
198
236
yield Path (tmp )
199
237
200
238
201
- @lru_cache ()
239
+ @async_cache ()
202
240
async def get_tags (repo_url : str , hash : str ) -> List [str ]:
203
241
"""
204
242
Retrieve a list of tags for a given commit.
@@ -221,7 +259,7 @@ async def get_tags(repo_url: str, hash: str) -> List[str]:
221
259
return await get_tags_in_repo (repo_path , hash )
222
260
223
261
224
- @lru_cache ()
262
+ @async_cache ()
225
263
async def get_tags_in_repo (repo_path : str , hash : str , fetch : bool = True ) -> List [str ]:
226
264
"""
227
265
Retrieve a list of tags for a given commit.
@@ -263,7 +301,7 @@ async def get_tags_in_repo(repo_path: str, hash: str, fetch: bool = True) -> Lis
263
301
return out .splitlines ()
264
302
265
303
266
- @lru_cache ()
304
+ @async_cache ()
267
305
async def get_hash (repo_url : str , rev : str ) -> str :
268
306
"""
269
307
Retrieve the hash for a given tag.
@@ -286,7 +324,7 @@ async def get_hash(repo_url: str, rev: str) -> str:
286
324
return await get_hash_in_repo (repo_path , rev )
287
325
288
326
289
- @lru_cache ()
327
+ @async_cache ()
290
328
async def get_hash_in_repo (repo_path : str , rev : str , fetch = True ) -> str :
291
329
"""
292
330
Retrieve the hash for a given tag.
0 commit comments