Skip to content

Commit 80446fc

Browse files
committed
Fixed #1066 Permanently Remove Recompile
1 parent bbc97e4 commit 80446fc

File tree

3 files changed

+8
-42
lines changed

3 files changed

+8
-42
lines changed

stumpy/cache.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -156,40 +156,6 @@ def _get_cache():
156156
return [f.name for f in pathlib.Path(numba_cache_dir).glob("*nb*") if f.is_file()]
157157

158158

159-
def _recompile():
160-
"""
161-
Recompile all njit functions
162-
163-
Parameters
164-
----------
165-
None
166-
167-
Returns
168-
-------
169-
None
170-
171-
Notes
172-
-----
173-
If the `numba` cache is enabled, this results in saving (and/or overwriting)
174-
the cached numba functions to disk.
175-
"""
176-
for module_name, func_name in get_njit_funcs():
177-
module = importlib.import_module(f".{module_name}", package="stumpy")
178-
func = getattr(module, func_name)
179-
try:
180-
func.recompile()
181-
except AttributeError as e:
182-
if (
183-
numba.config.DISABLE_JIT
184-
and str(e) == "'function' object has no attribute 'recompile'"
185-
):
186-
pass
187-
else: # pragma: no cover
188-
raise
189-
190-
return
191-
192-
193159
def _save():
194160
"""
195161
Save all njit functions
@@ -203,7 +169,6 @@ def _save():
203169
None
204170
"""
205171
_enable()
206-
_recompile()
207172

208173
return
209174

stumpy/fastmath.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,15 @@ def _set(module_name, func_name, flag):
5454
module = importlib.import_module(f".{module_name}", package="stumpy")
5555
func = getattr(module, func_name)
5656
try:
57-
func.targetoptions["fastmath"] = flag
58-
func.recompile()
57+
py_func = func.py_func # Copy raw Python function (independent of `njit`)
58+
njit_signature = func.targetoptions.copy() # Copy the `njit` arguments
59+
njit_signature.pop("nopython", None) # Pop redundant `nopython` declaration
60+
njit_signature["fastmath"] = flag # Apply new `fastmath` flag
61+
func = njit(py_func, **njit_signature) # Assign `njit` function with new target
62+
setattr(module, func_name, func) # Monkey-patch `njit` into global space
5963
except AttributeError as e:
6064
if numba.config.DISABLE_JIT and (
61-
str(e) == "'function' object has no attribute 'targetoptions'"
62-
or str(e) == "'function' object has no attribute 'recompile'"
65+
str(e) == "'function' object has no attribute 'py_func'"
6366
):
6467
pass
6568
else: # pragma: no cover

tests/test_precision.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from numba import cuda
1010

1111
import stumpy
12-
from stumpy import cache, config, core, fastmath
12+
from stumpy import config, core, fastmath
1313

1414
try:
1515
from numba.errors import NumbaPerformanceWarning
@@ -156,7 +156,6 @@ def test_snippets():
156156
fastmath._set(
157157
"core", "_calculate_squared_distance", {"nsz", "arcp", "contract", "afn"}
158158
)
159-
cache._recompile()
160159

161160
(
162161
cmp_snippets,
@@ -187,7 +186,6 @@ def test_snippets():
187186
if not numba.config.DISABLE_JIT: # pragma: no cover
188187
# Revert fastmath flag back to their default values
189188
fastmath._reset("core", "_calculate_squared_distance")
190-
cache._recompile()
191189

192190

193191
@pytest.mark.filterwarnings("ignore", category=NumbaPerformanceWarning)

0 commit comments

Comments
 (0)