Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 45 additions & 17 deletions stumpy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import ast
import importlib
import inspect
import os
import pathlib
import site
import warnings
Expand Down Expand Up @@ -102,58 +103,77 @@ def _enable():
raise


def _clear():
def _clear(cache_dir=None):
"""
Clear numba cache

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
None
"""
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
if cache_dir is not None: # pragma: no cover
numba_cache_dir = str(cache_dir)
elif "PYTEST_CURRENT_TEST" in os.environ:
numba_cache_dir = "stumpy/__pycache__"
else: # pragma: no cover
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"

[f.unlink() for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]


def clear():
def clear(cache_dir=None):
"""
Clear numba cache directory

Parameters
----------
None
cache_dir : str, default None
The path to the numba cache directory

Returns
-------
None
"""
warnings.warn(CACHE_WARNING)
_clear()
_clear(cache_dir)

return


def _get_cache():
def _get_cache(cache_dir=None):
"""
Retrieve a list of cached numba functions

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
out : list
A list of cached numba functions
"""
warnings.warn(CACHE_WARNING)
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
if cache_dir is not None: # pragma: no cover
numba_cache_dir = str(cache_dir)
elif "PYTEST_CURRENT_TEST" in os.environ:
numba_cache_dir = "stumpy/__pycache__"
else: # pragma: no cover
site_pkg_dir = site.getsitepackages()[0]
numba_cache_dir = site_pkg_dir + "/stumpy/__pycache__"

return [
f"{numba_cache_dir}/{f.name}"
for f in pathlib.Path(numba_cache_dir).glob("*nb*")
if f.is_file()
]


def _recompile():
Expand Down Expand Up @@ -190,32 +210,35 @@ def _recompile():
return


def _save():
def _save(cache_dir):
"""
Save all njit functions

Parameters
----------
None
cache_dir : str
The path to the numba cache directory

Returns
-------
None
"""
_enable()
_clear(cache_dir)
_recompile()

return


def save():
def save(cache_dir=None):
"""
Save/overwrite all the cache data files of
all-so-far compiled njit functions.

Parameters
----------
None
cache_dir : str, default None
The path to the numba cache directory

Returns
-------
Expand All @@ -227,6 +250,11 @@ def save():
else: # pragma: no cover
warnings.warn(CACHE_WARNING)

_save()
if numba.config.CACHE_DIR != "":
msg = "Found user specified `NUMBA_CACHE_DIR`/`numba.config.CACHE_DIR`. "
msg += "The `stumpy` cache files may not be saved/cleared correctly!"
warnings.warn(msg)

_save(cache_dir)

return
7 changes: 5 additions & 2 deletions stumpy/fastmath.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import importlib
import warnings

import numba
from numba import njit
Expand Down Expand Up @@ -55,11 +56,13 @@ def _set(module_name, func_name, flag):
func = getattr(module, func_name)
try:
func.targetoptions["fastmath"] = flag
func.recompile()
msg = "One or more fastmath flags have been set/reset. "
msg += "Please call `cache._recompile()` to ensure that all njit functions "
msg += "are properly recompiled."
warnings.warn(msg)
except AttributeError as e:
if numba.config.DISABLE_JIT and (
str(e) == "'function' object has no attribute 'targetoptions'"
or str(e) == "'function' object has no attribute 'recompile'"
):
pass
else: # pragma: no cover
Expand Down
13 changes: 10 additions & 3 deletions tests/test_cache.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numba
import numpy as np

from stumpy import cache, stump
Expand All @@ -11,17 +12,23 @@ def test_cache_get_njit_funcs():
def test_cache_save_after_clear():
T = np.random.rand(10)
m = 3
stump(T, m)

cache.save()
stump(T, m)
ref_cache = cache._get_cache()

if numba.config.DISABLE_JIT:
assert len(ref_cache) == 0
else: # pragma: no cover
assert len(ref_cache) > 0

cache.clear()
# testing cache._clear()
assert len(cache._get_cache()) == 0

cache.save()
stump(T, m)
comp_cache = cache._get_cache()

# testing cache._save() after cache._clear()
assert sorted(ref_cache) == sorted(comp_cache)

cache.clear()
8 changes: 7 additions & 1 deletion tests/test_fastmath.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numba
import numpy as np

from stumpy import fastmath
from stumpy import cache, fastmath


def test_set():
Expand All @@ -11,11 +11,13 @@ def test_set():

# case1: flag=False
fastmath._set("fastmath", "_add_assoc", flag=False)
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

# case2: flag={'reassoc', 'nsz'}
fastmath._set("fastmath", "_add_assoc", flag={"reassoc", "nsz"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
if numba.config.DISABLE_JIT:
assert np.isnan(out)
Expand All @@ -24,11 +26,13 @@ def test_set():

# case3: flag={'reassoc'}
fastmath._set("fastmath", "_add_assoc", flag={"reassoc"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

# case4: flag={'nsz'}
fastmath._set("fastmath", "_add_assoc", flag={"nsz"})
cache._recompile()
out = fastmath._add_assoc(0, np.inf)
assert np.isnan(out)

Expand All @@ -39,7 +43,9 @@ def test_reset():
# https://numba.pydata.org/numba-doc/dev/user/performance-tips.html#fastmath
# and then reset it to the default value, i.e. `True`
fastmath._set("fastmath", "_add_assoc", False)
cache._recompile()
fastmath._reset("fastmath", "_add_assoc")
cache._recompile()
if numba.config.DISABLE_JIT:
assert np.isnan(fastmath._add_assoc(0.0, np.inf))
else: # pragma: no cover
Expand Down