Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 40 additions & 17 deletions stumpy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,58 +102,73 @@ def _enable():
raise


def _clear():
def _clear(cache_dir=None):
"""
Clear numba cache

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
None
"""
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
if cache_dir is not None: # pragma: no cover
numba_cache_dir = str(cache_dir)
else: # pragma: no cover
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"

[f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]


def clear():
def clear(cache_dir=None):
"""
Clear numba cache directory

Parameters
----------
None
cache_dir : str, default None
The path to the numba cache directory

Returns
-------
None
"""
warnings.warn(CACHE_WARNING)
_clear()
_clear(cache_dir)

return


def _get_cache():
def _get_cache(cache_dir=None):
"""
Retrieve a list of cached numba functions

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
out : list
A list of cached numba functions
"""
warnings.warn(CACHE_WARNING)
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
if cache_dir is not None: # pragma: no cover
numba_cache_dir = str(cache_dir)
else: # pragma: no cover
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"

return [
f"{numba_cache_dir}/{f.name}"
for f in pathlib.Path(numba_cache_dir).glob("*nb*")
if f.is_file()
]


def _recompile():
Expand Down Expand Up @@ -190,32 +205,35 @@ def _recompile():
return


def _save():
def _save(cache_dir):
"""
Save all njit functions

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
None
"""
_enable()
_clear(cache_dir)
_recompile()

return


def save():
def save(cache_dir=None):
"""
Save/overwrite all the cache data files of
all-so-far compiled njit functions.

Parameters
----------
None
cache_dir : str, default None
The path to the numba cache directory

Returns
-------
Expand All @@ -227,6 +245,11 @@ def save():
else: # pragma: no cover
warnings.warn(CACHE_WARNING)

_save()
if numba.config.CACHE_DIR != "": # pragma: no cover
msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. "
msg += "The `stumpy` cache files may not be saved/cleared correctly!"
warnings.warn(msg)

_save(cache_dir)

return
8 changes: 6 additions & 2 deletions stumpy/fastmath.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import warnings

import numba
from numba import njit
Expand Down Expand Up @@ -55,12 +56,15 @@ def _set(module_name, func_name, flag):
func = getattr(module, func_name)
try:
func.targetoptions["fastmath"] = flag
func.recompile()
msg = "One or more fastmath flags have been set/reset. "
msg += "Please call `cache._recompile()` to ensure that all njit functions "
msg += "are properly recompiled."
warnings.warn(msg)
except AttributeError as e:
if numba.config.DISABLE_JIT and (
str(e) == "'function' object has no attribute 'targetoptions'"
or str(e) == "'function' object has no attribute 'recompile'"
):
warnings.warn("Fastmath flags could not be set as Numba JIT is disabled")
pass
else: # pragma: no cover
raise
Expand Down
21 changes: 12 additions & 9 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@ def test_cache_get_njit_funcs():
def test_cache_save_after_clear():
T = np.random.rand(10)
m = 3
stump(T, m)

cache.save()
ref_cache = cache._get_cache()
cache_dir = "stumpy/__pycache__"

cache.clear()
# testing cache._clear()
assert len(cache._get_cache()) == 0
cache.save(cache_dir)
stump(T, m)
ref_cache = cache._get_cache(cache_dir)

cache.save()
comp_cache = cache._get_cache()
cache.clear(cache_dir)
assert len(cache._get_cache(cache_dir)) == 0

cache.save(cache_dir)
stump(T, m)
comp_cache = cache._get_cache(cache_dir)

# testing cache._save() after cache._clear()
assert sorted(ref_cache) == sorted(comp_cache)

cache.clear(cache_dir)
8 changes: 7 additions & 1 deletion tests/test_fastmath.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numba
import numpy as np

from stumpy import fastmath
from stumpy import cache, fastmath


def test_set():
Expand All @@ -11,11 +11,13 @@ def test_set():

# case1: flag=False
fastmath._set("fastmath", "_add_assoc", flag=False)
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

# case2: flag={'reassoc', 'nsz'}
fastmath._set("fastmath", "_add_assoc", flag={"reassoc", "nsz"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
if numba.config.DISABLE_JIT:
assert np.isnan(out)
Expand All @@ -24,11 +26,13 @@ def test_set():

# case3: flag={'reassoc'}
fastmath._set("fastmath", "_add_assoc", flag={"reassoc"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

# case4: flag={'nsz'}
fastmath._set("fastmath", "_add_assoc", flag={"nsz"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

Expand All @@ -39,7 +43,9 @@ def test_reset():
# https://numba.pydata.org/numba-doc/dev/user/performance-tips.html#fastmath
# and then reset it to the default value, i.e. `True`
fastmath._set("fastmath", "_add_assoc", False)
cache._recompile()
fastmath._reset("fastmath", "_add_assoc")
cache._recompile()
if numba.config.DISABLE_JIT:
assert np.isnan(fastmath._add_assoc(0.0, np.inf))
else: # pragma: no cover
Expand Down