Skip to content

Commit ce89694

Browse files
committed
Added numba cache dir for pytest
1 parent bbc97e4 commit ce89694

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

stumpy/cache.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import ast
66
import importlib
77
import inspect
8+
import os
89
import pathlib
910
import site
1011
import warnings
@@ -102,57 +103,71 @@ def _enable():
102103
raise
103104

104105

105-
def _clear():
106+
def _clear(cache_dir=None):
106107
"""
107108
Clear numba cache
108109
109110
Parameters
110111
----------
111-
None
112+
cache_dir : str
113+
The path to the numba cache directory
112114
113115
Returns
114116
-------
115117
None
116118
"""
117-
site_pkg_dir = site.getsitepackages()[0]
118-
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
119+
if cache_dir is not None: # pragma: no cover
120+
numba_cache_dir = str(cache_dir)
121+
elif "PYTEST_CURRENT_TEST" in os.environ:
122+
numba_cache_dir = "stumpy/__pycache__"
123+
else: # pragma: no cover
124+
site_pkg_dir = site.getsitepackages()[0]
125+
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
126+
119127
[f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
120128

121129

122-
def clear():
130+
def clear(cache_dir=None):
123131
"""
124132
Clear numba cache directory
125133
126134
Parameters
127135
----------
128-
None
136+
cache_dir : str
137+
The path to the numba cache directory
129138
130139
Returns
131140
-------
132141
None
133142
"""
134143
warnings.warn(CACHE_WARNING)
135-
_clear()
144+
_clear(cache_dir)
136145

137146
return
138147

139148

140-
def _get_cache():
149+
def _get_cache(cache_dir=None):
141150
"""
142151
Retrieve a list of cached numba functions
143152
144153
Parameters
145154
----------
146-
None
155+
cache_dir : str
156+
The path to the numba cache directory
147157
148158
Returns
149159
-------
150160
out : list
151161
A list of cached numba functions
152162
"""
153163
warnings.warn(CACHE_WARNING)
154-
site_pkg_dir = site.getsitepackages()[0]
155-
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
164+
if cache_dir is not None: # pragma: no cover
165+
numba_cache_dir = str(cache_dir)
166+
if "PYTEST_CURRENT_TEST" in os.environ:
167+
numba_cache_dir = "stumpy/__pycache__"
168+
else: # pragma: no cover
169+
site_pkg_dir = site.getsitepackages()[0]
170+
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
156171
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
157172

158173

tests/test_cache.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,19 +9,24 @@ def test_cache_get_njit_funcs():
99

1010

1111
def test_cache_save_after_clear():
12+
cache.clear()
13+
cache.save()
14+
1215
T = np.random.rand(10)
1316
m = 3
1417
stump(T, m)
1518

16-
cache.save()
1719
ref_cache = cache._get_cache()
1820

1921
cache.clear()
2022
# testing cache._clear()
2123
assert len(cache._get_cache()) == 0
2224

2325
cache.save()
26+
stump(T, m)
2427
comp_cache = cache._get_cache()
2528

2629
# testing cache._save() after cache._clear()
2730
assert sorted(ref_cache) == sorted(comp_cache)
31+
32+
cache.clear()

0 commit comments

Comments
 (0)