Skip to content

Commit af01830

Browse files
Fix caching coroutines.
* check_pre_commit_config_frozen.py
1 parent 123b75a commit af01830

File tree

1 file changed

+44
-6
lines changed

1 file changed

+44
-6
lines changed

check_pre_commit_config_frozen.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,22 @@
1515
from asyncio import create_subprocess_exec, gather
1616
from contextlib import asynccontextmanager, contextmanager
1717
from dataclasses import dataclass
18-
from functools import lru_cache, partial
18+
from functools import partial, wraps
1919
from io import StringIO
2020
from pathlib import Path
21-
from typing import Any, AsyncGenerator, Dict, List, Mapping, Optional, Tuple, cast
21+
from typing import (
22+
Any,
23+
AsyncGenerator,
24+
Callable,
25+
Coroutine,
26+
Dict,
27+
List,
28+
Mapping,
29+
Optional,
30+
Tuple,
31+
TypeVar,
32+
cast,
33+
)
2234

2335
from ruamel.yaml import YAML
2436
from ruamel.yaml.error import YAMLError, YAMLWarning
@@ -44,6 +56,32 @@
4456
logger = logging.getLogger(__name__)
4557
logger.addHandler(logging.NullHandler())
4658

59+
60+
# -- Utils -------------------------------------------------------------------
61+
62+
T = TypeVar("T")
63+
64+
65+
class async_cache:
66+
def __init__(self, cache_dict=None):
67+
self._dict = cache_dict or {}
68+
69+
def __call__(
70+
self, func: Callable[..., Coroutine[None, None, T]]
71+
) -> Callable[..., Coroutine[None, None, T]]:
72+
@wraps(func)
73+
async def get(*args, **kwargs):
74+
key = (args, tuple(map(tuple, kwargs.items())))
75+
try:
76+
return self._dict[key]
77+
except KeyError:
78+
value = await func(*args, **kwargs)
79+
self._dict[key] = value
80+
return value
81+
82+
return get
83+
84+
4785
# -- Git ---------------------------------------------------------------------
4886

4987
# The git commands in this section is partially sourced and modified from
@@ -198,7 +236,7 @@ async def tmp_repo(repo: str) -> AsyncGenerator[Path, Any]:
198236
yield Path(tmp)
199237

200238

201-
@lru_cache()
239+
@async_cache()
202240
async def get_tags(repo_url: str, hash: str) -> List[str]:
203241
"""
204242
Retrieve a list of tags for a given commit.
@@ -221,7 +259,7 @@ async def get_tags(repo_url: str, hash: str) -> List[str]:
221259
return await get_tags_in_repo(repo_path, hash)
222260

223261

224-
@lru_cache()
262+
@async_cache()
225263
async def get_tags_in_repo(repo_path: str, hash: str, fetch: bool = True) -> List[str]:
226264
"""
227265
Retrieve a list of tags for a given commit.
@@ -263,7 +301,7 @@ async def get_tags_in_repo(repo_path: str, hash: str, fetch: bool = True) -> Lis
263301
return out.splitlines()
264302

265303

266-
@lru_cache()
304+
@async_cache()
267305
async def get_hash(repo_url: str, rev: str) -> str:
268306
"""
269307
Retrieve the hash for a given tag.
@@ -286,7 +324,7 @@ async def get_hash(repo_url: str, rev: str) -> str:
286324
return await get_hash_in_repo(repo_path, rev)
287325

288326

289-
@lru_cache()
327+
@async_cache()
290328
async def get_hash_in_repo(repo_path: str, rev: str, fetch=True) -> str:
291329
"""
292330
Retrieve the hash for a given tag.

0 commit comments

Comments
 (0)