diff --git a/fastmath.py b/fastmath.py new file mode 100755 index 000000000..b6fea39af --- /dev/null +++ b/fastmath.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python + +import argparse +import ast +import importlib +import pathlib + + +def get_njit_funcs(pkg_dir): + """ + Identify all njit functions + + Parameters + ---------- + pkg_dir : str + The path to the directory containing some .py files + + Returns + ------- + njit_funcs : list + A list of all njit functions, where each element is a tuple of the form + (module_name, func_name) + """ + ignore_py_files = ["__init__", "__pycache__"] + pkg_dir = pathlib.Path(pkg_dir) + + module_names = [] + for fname in pkg_dir.iterdir(): + if fname.stem not in ignore_py_files and not fname.stem.startswith("."): + module_names.append(fname.stem) + + njit_funcs = [] + for module_name in module_names: + filepath = pkg_dir / f"{module_name}.py" + file_contents = "" + with open(filepath, encoding="utf8") as f: + file_contents = f.read() + module = ast.parse(file_contents) + for node in module.body: + if isinstance(node, ast.FunctionDef): + func_name = node.name + for decorator in node.decorator_list: + decorator_name = None + if isinstance(decorator, ast.Name): + # Bare decorator + decorator_name = decorator.id + if isinstance(decorator, ast.Call) and isinstance( + decorator.func, ast.Name + ): + # Decorator is a function + decorator_name = decorator.func.id + + if decorator_name == "njit": + njit_funcs.append((module_name, func_name)) + + return njit_funcs + + +def check_fastmath(pkg_dir, pkg_name): + """ + Check if all njit functions have the `fastmath` flag set + + Parameters + ---------- + pkg_dir : str + The path to the directory containing some .py files + + pkg_name : str + The name of the package + + Returns + ------- + None + """ + missing_fastmath = [] # list of njit functions with missing fastmath flags + for module_name, func_name in get_njit_funcs(pkg_dir): + module = importlib.import_module(f".{module_name}", package=pkg_name) + func = getattr(module, func_name) + if "fastmath" not in func.targetoptions.keys(): + missing_fastmath.append(f"{module_name}.{func_name}") + + if len(missing_fastmath) > 0: + msg = ( + "Found one or more `@njit` functions that are missing the `fastmath` flag. " + + f"The functions are:\n {missing_fastmath}\n" + ) + raise ValueError(msg) + + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--check", dest="pkg_dir") + args = parser.parse_args() + + if args.pkg_dir: + pkg_dir = pathlib.Path(args.pkg_dir) + pkg_name = pkg_dir.name + check_fastmath(str(pkg_dir), pkg_name) diff --git a/stumpy/cache.py b/stumpy/cache.py index 398387724..76fb685bc 100644 --- a/stumpy/cache.py +++ b/stumpy/cache.py @@ -5,7 +5,6 @@ import ast import importlib import pathlib -import pkgutil import site import warnings @@ -28,13 +27,17 @@ def get_njit_funcs(): out : list A list of (`module_name`, `func_name`) pairs """ + ignore_py_files = ["__init__", "__pycache__"] + pkg_dir = pathlib.Path(__file__).parent - module_names = [name for _, name, _ in pkgutil.iter_modules([str(pkg_dir)])] + module_names = [] + for fname in pkg_dir.iterdir(): + if fname.stem not in ignore_py_files and not fname.stem.startswith("."): + module_names.append(fname.stem) njit_funcs = [] - for module_name in module_names: - filepath = pathlib.Path(__file__).parent / f"{module_name}.py" + filepath = pkg_dir / f"{module_name}.py" file_contents = "" with open(filepath, encoding="utf8") as f: file_contents = f.read() @@ -43,11 +46,18 @@ def get_njit_funcs(): if isinstance(node, ast.FunctionDef): func_name = node.name for decorator in node.decorator_list: + decorator_name = None + if isinstance(decorator, ast.Name): + # Bare decorator + decorator_name = decorator.id if isinstance(decorator, ast.Call) and isinstance( decorator.func, ast.Name ): - if decorator.func.id == "njit": - njit_funcs.append((module_name, func_name)) + # Decorator is a function + decorator_name = decorator.func.id + + if decorator_name == "njit": + njit_funcs.append((module_name, func_name)) return njit_funcs diff --git a/stumpy/core.py b/stumpy/core.py index 4cdaea02a..7e8a18c42 100644 --- a/stumpy/core.py +++ b/stumpy/core.py @@ -2347,6 +2347,7 @@ def _count_diagonal_ndist(diags, m, n_A, n_B): @njit( # "i8[:, :](i8[:], i8, b1)" + fastmath=True ) def _get_array_ranges(a, n_chunks, truncate): """ @@ -2395,6 +2396,7 @@ def _get_array_ranges(a, n_chunks, truncate): @njit( # "i8[:, :](i8, i8, b1)" + fastmath=True ) def _get_ranges(size, n_chunks, truncate): """ @@ -3247,7 +3249,7 @@ def _select_P_ABBA_value(P_ABBA, k, custom_func=None): return MPdist -@njit() +@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}) def _merge_topk_PI(PA, PB, IA, IB): """ Merge two top-k matrix profiles `PA` and `PB`, and update `PA` (in place). @@ -3320,7 +3322,7 @@ def _merge_topk_PI(PA, PB, IA, IB): IA[i] = tmp_I -@njit() +@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}) def _merge_topk_ρI(ρA, ρB, IA, IB): """ Merge two top-k pearson profiles `ρA` and `ρB`, and update `ρA` (in place). @@ -3394,7 +3396,7 @@ def _merge_topk_ρI(ρA, ρB, IA, IB): IA[i] = tmp_I -@njit() +@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}) def _shift_insert_at_index(a, idx, v, shift="right"): """ If `shift=right` (default), all elements in `a[idx:]` are shifted to the right by @@ -4370,7 +4372,7 @@ def get_ray_nworkers(ray_client): return int(ray_client.cluster_resources().get("CPU")) -@njit +@njit(fastmath={"nsz", "arcp", "contract", "afn", "reassoc"}) def _update_incremental_PI(D, P, I, excl_zone, n_appended=0): """ Given the 1D array distance profile, `D`, of the last subsequence of T, diff --git a/test.sh b/test.sh index 06fe9819c..0136e16f5 100755 --- a/test.sh +++ b/test.sh @@ -93,6 +93,13 @@ check_print() fi } +check_fastmath() +{ + echo "Checking Missing fastmath flags in njit functions" + ./fastmath.py --check stumpy + check_errs $? +} + check_naive() { # Check if there are any naive implementations not at start of test file @@ -146,14 +153,14 @@ set_ray_coveragerc() show_coverage_report() { set_ray_coveragerc - coverage report -m --fail-under=100 --skip-covered --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc + coverage report -m --fail-under=100 --skip-covered --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc } gen_coverage_xml_report() { # This function saves the coverage report in Cobertura XML format, which is compatible with codecov set_ray_coveragerc - coverage xml -o $fcoveragexml --fail-under=100 --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc + coverage xml -o $fcoveragexml --fail-under=100 --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc } test_custom() @@ -333,6 +340,12 @@ check_print check_naive check_ray + +if [[ -z $NUMBA_DISABLE_JIT || $NUMBA_DISABLE_JIT -eq 0 ]]; then + check_fastmath +fi + + if [[ $test_mode == "notebooks" ]]; then echo "Executing Tutorial Notebooks Only" convert_notebooks