Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 54 additions & 14 deletions stumpy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CACHE_WARNING += "and should never be used or depended upon as it is not supported! "
CACHE_WARNING += "All caching capabilities are not tested and may be removed/changed "
CACHE_WARNING += "without prior notice. Please proceed with caution!"
CACHE_CLEARED = True


def get_njit_funcs():
Expand Down Expand Up @@ -102,58 +103,77 @@ def _enable():
raise


def _clear():
def _clear(cache_dir=None):
"""
Clear numba cache

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
None
"""
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
global CACHE_CLEARED

if cache_dir is not None: # pragma: no cover
numba_cache_dir = str(cache_dir)
else: # pragma: no cover
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"

[f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]

CACHE_CLEARED = True


def clear():
def clear(cache_dir=None):
"""
Clear numba cache directory

Parameters
----------
None
cache_dir : str, default None
The path to the numba cache directory

Returns
-------
None
"""
warnings.warn(CACHE_WARNING)
_clear()
_clear(cache_dir)

return


def _get_cache():
def _get_cache(cache_dir=None):
"""
Retrieve a list of cached numba functions

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
out : list
A list of cached numba functions
"""
warnings.warn(CACHE_WARNING)
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
if cache_dir is not None: # pragma: no cover
numba_cache_dir = str(cache_dir)
else: # pragma: no cover
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"

return [
f"{numba_cache_dir}/{f.name}"
for f in pathlib.Path(numba_cache_dir).glob("*nb*")
if f.is_file()
]


def _recompile():
Expand Down Expand Up @@ -202,16 +222,24 @@ def _save():
-------
None
"""
global CACHE_CLEARED

if not CACHE_CLEARED: # pragma: no cover
msg = "Numba njit cached files are not cleared before saving/overwriting. "
msg = "You may need to call `cache.clear()` before calling `cache.save()`."
warnings.warn(msg)

_enable()
_recompile()

CACHE_CLEARED = False

return


def save():
"""
Save/overwrite all the cache data files of
all-so-far compiled njit functions.
Save/overwrite all of the cached njit functions.

Parameters
----------
Expand All @@ -220,13 +248,25 @@ def save():
Returns
-------
None

Notes
-----
The cache is never cleared before saving/overwriting and may be explicitly
cleared by calling `cache.clear()` before saving.
"""
global CACHE_CLEARED

if numba.config.DISABLE_JIT:
msg = "Could not save/cache function because NUMBA JIT is disabled"
warnings.warn(msg)
else: # pragma: no cover
warnings.warn(CACHE_WARNING)

if numba.config.CACHE_DIR != "": # pragma: no cover
msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. "
msg += "The `stumpy` cache files may not be saved/cleared correctly!"
warnings.warn(msg)

_save()

return
8 changes: 6 additions & 2 deletions stumpy/fastmath.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import warnings

import numba
from numba import njit
Expand Down Expand Up @@ -55,12 +56,15 @@ def _set(module_name, func_name, flag):
func = getattr(module, func_name)
try:
func.targetoptions["fastmath"] = flag
func.recompile()
msg = "One or more fastmath flags have been set/reset. "
msg += "Please call `cache._recompile()` to ensure that all njit functions "
msg += "are properly recompiled."
warnings.warn(msg)
except AttributeError as e:
if numba.config.DISABLE_JIT and (
str(e) == "'function' object has no attribute 'targetoptions'"
or str(e) == "'function' object has no attribute 'recompile'"
):
warnings.warn("Fastmath flags could not be set as Numba JIT is disabled")
pass
else: # pragma: no cover
raise
Expand Down
25 changes: 18 additions & 7 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numba
import numpy as np

from stumpy import cache, stump
Expand All @@ -11,17 +12,27 @@ def test_cache_get_njit_funcs():
def test_cache_save_after_clear():
T = np.random.rand(10)
m = 3
stump(T, m)

cache_dir = "stumpy/__pycache__"

cache.clear(cache_dir)
cache.save()
ref_cache = cache._get_cache()

cache.clear()
# testing cache._clear()
assert len(cache._get_cache()) == 0
stump(T, m)
ref_cache = cache._get_cache(cache_dir)

if numba.config.DISABLE_JIT:
assert len(ref_cache) == 0
else: # pragma: no cover
assert len(ref_cache) > 0

cache.clear(cache_dir)
assert len(cache._get_cache(cache_dir)) == 0
cache.save()
comp_cache = cache._get_cache()

# testing cache._save() after cache._clear()
stump(T, m)
comp_cache = cache._get_cache(cache_dir)

assert sorted(ref_cache) == sorted(comp_cache)

cache.clear(cache_dir)
8 changes: 7 additions & 1 deletion tests/test_fastmath.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numba
import numpy as np

from stumpy import fastmath
from stumpy import cache, fastmath


def test_set():
Expand All @@ -11,11 +11,13 @@ def test_set():

# case1: flag=False
fastmath._set("fastmath", "_add_assoc", flag=False)
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

# case2: flag={'reassoc', 'nsz'}
fastmath._set("fastmath", "_add_assoc", flag={"reassoc", "nsz"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
if numba.config.DISABLE_JIT:
assert np.isnan(out)
Expand All @@ -24,11 +26,13 @@ def test_set():

# case3: flag={'reassoc'}
fastmath._set("fastmath", "_add_assoc", flag={"reassoc"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

# case4: flag={'nsz'}
fastmath._set("fastmath", "_add_assoc", flag={"nsz"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

Expand All @@ -39,7 +43,9 @@ def test_reset():
# https://numba.pydata.org/numba-doc/dev/user/performance-tips.html#fastmath
# and then reset it to the default value, i.e. `True`
fastmath._set("fastmath", "_add_assoc", False)
cache._recompile()
fastmath._reset("fastmath", "_add_assoc")
cache._recompile()
if numba.config.DISABLE_JIT:
assert np.isnan(fastmath._add_assoc(0.0, np.inf))
else: # pragma: no cover
Expand Down