Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions fastmath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
#!/usr/bin/env python

import argparse
import ast
import importlib
import pathlib


def get_njit_funcs(pkg_dir):
"""
Identify all njit functions

Parameters
----------
pkg_dir : str
The path to the directory containing some .py files

Returns
-------
njit_funcs : list
A list of all njit functions, where each element is a tuple of the form
(module_name, func_name)
"""
ignore_py_files = ["__init__", "__pycache__"]
pkg_dir = pathlib.Path(pkg_dir)

module_names = []
for fname in pkg_dir.iterdir():
if fname.stem not in ignore_py_files and not fname.stem.startswith("."):
module_names.append(fname.stem)

njit_funcs = []
for module_name in module_names:
filepath = pkg_dir / f"{module_name}.py"
file_contents = ""
with open(filepath, encoding="utf8") as f:
file_contents = f.read()
module = ast.parse(file_contents)
for node in module.body:
if isinstance(node, ast.FunctionDef):
func_name = node.name
for decorator in node.decorator_list:
decorator_name = None
if isinstance(decorator, ast.Name):
# Bare decorator
decorator_name = decorator.id
if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Name
):
# Decorator is a function
decorator_name = decorator.func.id

if decorator_name == "njit":
njit_funcs.append((module_name, func_name))

return njit_funcs


def check_fastmath(pkg_dir, pkg_name):
"""
Check if all njit functions have the `fastmath` flag set

Parameters
----------
pkg_dir : str
The path to the directory containing some .py files

pkg_name : str
The name of the package

Returns
-------
None
"""
missing_fastmath = [] # list of njit functions with missing fastmath flags
for module_name, func_name in get_njit_funcs(pkg_dir):
module = importlib.import_module(f".{module_name}", package=pkg_name)
func = getattr(module, func_name)
if "fastmath" not in func.targetoptions.keys():
missing_fastmath.append(f"{module_name}.{func_name}")

if len(missing_fastmath) > 0:
msg = (
"Found one or more `@njit` functions that are missing the `fastmath` flag. "
+ f"The functions are:\n {missing_fastmath}\n"
)
raise ValueError(msg)

return


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check", dest="pkg_dir")
args = parser.parse_args()

if args.pkg_dir:
pkg_dir = pathlib.Path(args.pkg_dir)
pkg_name = pkg_dir.name
check_fastmath(str(pkg_dir), pkg_name)
22 changes: 16 additions & 6 deletions stumpy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import ast
import importlib
import pathlib
import pkgutil
import site
import warnings

Expand All @@ -28,13 +27,17 @@ def get_njit_funcs():
out : list
A list of (`module_name`, `func_name`) pairs
"""
ignore_py_files = ["__init__", "__pycache__"]

pkg_dir = pathlib.Path(__file__).parent
module_names = [name for _, name, _ in pkgutil.iter_modules([str(pkg_dir)])]
module_names = []
for fname in pkg_dir.iterdir():
if fname.stem not in ignore_py_files and not fname.stem.startswith("."):
module_names.append(fname.stem)

njit_funcs = []

for module_name in module_names:
filepath = pathlib.Path(__file__).parent / f"{module_name}.py"
filepath = pkg_dir / f"{module_name}.py"
file_contents = ""
with open(filepath, encoding="utf8") as f:
file_contents = f.read()
Expand All @@ -43,11 +46,18 @@ def get_njit_funcs():
if isinstance(node, ast.FunctionDef):
func_name = node.name
for decorator in node.decorator_list:
decorator_name = None
if isinstance(decorator, ast.Name):
# Bare decorator
decorator_name = decorator.id
if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Name
):
if decorator.func.id == "njit":
njit_funcs.append((module_name, func_name))
# Decorator is a function
decorator_name = decorator.func.id

if decorator_name == "njit":
njit_funcs.append((module_name, func_name))

return njit_funcs

Expand Down
10 changes: 6 additions & 4 deletions stumpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2347,6 +2347,7 @@ def _count_diagonal_ndist(diags, m, n_A, n_B):

@njit(
# "i8[:, :](i8[:], i8, b1)"
fastmath=True
)
def _get_array_ranges(a, n_chunks, truncate):
"""
Expand Down Expand Up @@ -2395,6 +2396,7 @@ def _get_array_ranges(a, n_chunks, truncate):

@njit(
# "i8[:, :](i8, i8, b1)"
fastmath=True
)
def _get_ranges(size, n_chunks, truncate):
"""
Expand Down Expand Up @@ -3247,7 +3249,7 @@ def _select_P_ABBA_value(P_ABBA, k, custom_func=None):
return MPdist


@njit()
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _merge_topk_PI(PA, PB, IA, IB):
"""
Merge two top-k matrix profiles `PA` and `PB`, and update `PA` (in place).
Expand Down Expand Up @@ -3320,7 +3322,7 @@ def _merge_topk_PI(PA, PB, IA, IB):
IA[i] = tmp_I


@njit()
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _merge_topk_ρI(ρA, ρB, IA, IB):
"""
Merge two top-k pearson profiles `ρA` and `ρB`, and update `ρA` (in place).
Expand Down Expand Up @@ -3394,7 +3396,7 @@ def _merge_topk_ρI(ρA, ρB, IA, IB):
IA[i] = tmp_I


@njit()
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _shift_insert_at_index(a, idx, v, shift="right"):
"""
If `shift=right` (default), all elements in `a[idx:]` are shifted to the right by
Expand Down Expand Up @@ -4370,7 +4372,7 @@ def get_ray_nworkers(ray_client):
return int(ray_client.cluster_resources().get("CPU"))


@njit
@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"})
def _update_incremental_PI(D, P, I, excl_zone, n_appended=0):
"""
Given the 1D array distance profile, `D`, of the last subsequence of T,
Expand Down
17 changes: 15 additions & 2 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ check_print()
fi
}

check_fastmath()
{
echo "Checking Missing fastmath flags in njit functions"
./fastmath.py --check stumpy
check_errs $?
}

check_naive()
{
# Check if there are any naive implementations not at start of test file
Expand Down Expand Up @@ -146,14 +153,14 @@ set_ray_coveragerc()
show_coverage_report()
{
set_ray_coveragerc
coverage report -m --fail-under=100 --skip-covered --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
coverage report -m --fail-under=100 --skip-covered --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
}

gen_coverage_xml_report()
{
# This function saves the coverage report in Cobertura XML format, which is compatible with codecov
set_ray_coveragerc
coverage xml -o $fcoveragexml --fail-under=100 --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
coverage xml -o $fcoveragexml --fail-under=100 --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
}

test_custom()
Expand Down Expand Up @@ -333,6 +340,12 @@ check_print
check_naive
check_ray


if [[ -z $NUMBA_DISABLE_JIT || $NUMBA_DISABLE_JIT -eq 0 ]]; then
check_fastmath
fi


if [[ $test_mode == "notebooks" ]]; then
echo "Executing Tutorial Notebooks Only"
convert_notebooks
Expand Down
Loading