Skip to content

Commit e7d9b2b

Browse files
committed
Cleaned up from comments
1 parent a505b87 commit e7d9b2b

File tree

2 files changed

+21
-12
lines changed

2 files changed

+21
-12
lines changed

stumpy/cache.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,39 +205,40 @@ def _recompile():
205205
return
206206

207207

208-
def _save(cache_dir):
208+
def _save():
209209
"""
210210
Save all njit functions
211211
212212
Parameters
213213
----------
214-
cache_dir : str
215-
The path to the numba cache directory
214+
None
216215
217216
Returns
218217
-------
219218
None
220219
"""
221220
_enable()
222-
_clear(cache_dir)
223221
_recompile()
224222

225223
return
226224

227225

228-
def save(cache_dir=None):
226+
def save():
229227
"""
230-
Save/overwrite all the cache data files of
231-
all-so-far compiled njit functions.
228+
Save/overwrite all of the cached njit functions.
232229
233230
Parameters
234231
----------
235-
cache_dir : str, default None
236-
The path to the numba cache directory
232+
None
237233
238234
Returns
239235
-------
240236
None
237+
238+
Notes
239+
-----
240+
The cache is never cleared before saving/overwriting and may be explicitly
241+
cleared by calling `cache.clear()` before saving.
241242
"""
242243
if numba.config.DISABLE_JIT:
243244
msg = "Could not save/cache function because NUMBA JIT is disabled"
@@ -250,6 +251,6 @@ def save(cache_dir=None):
250251
msg += "The `stumpy` cache files may not be saved/cleared correctly!"
251252
warnings.warn(msg)
252253

253-
_save(cache_dir)
254+
_save()
254255

255256
return

tests/test_cache.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numba
12
import numpy as np
23

34
from stumpy import cache, stump
@@ -14,14 +15,21 @@ def test_cache_save_after_clear():
1415

1516
cache_dir = "stumpy/__pycache__"
1617

17-
cache.save(cache_dir)
18+
cache.clear(cache_dir)
19+
cache.save()
20+
1821
stump(T, m)
1922
ref_cache = cache._get_cache(cache_dir)
2023

24+
if numba.config.DISABLE_JIT:
25+
assert len(ref_cache) == 0
26+
else: # pragma: no cover
27+
assert len(ref_cache) > 0
28+
2129
cache.clear(cache_dir)
2230
assert len(cache._get_cache(cache_dir)) == 0
31+
cache.save()
2332

24-
cache.save(cache_dir)
2533
stump(T, m)
2634
comp_cache = cache._get_cache(cache_dir)
2735

0 commit comments

Comments
 (0)