Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions fastmath.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python

import argparse
import importlib
import pathlib
import re


def get_njit_funcs():
"""
Retrieve a list of all njit functions
Parameters
----------
None
Returns
-------
njit_funcs : list
A list of all njit functions, where each element is a tuple of the form
(module_name, func_name)
"""
pattern = r"@njit.*?def\s+\w+\("

stumpy_path = pathlib.Path(__file__).parent / "stumpy"
filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file())

out = []
ignore = ["__init__.py", "__pycache__"]
for filepath in filepaths:
fname = filepath.name
if fname not in ignore and fname.endswith(".py"):
file_contents = ""
with open(filepath, encoding="utf8") as f:
file_contents = f.read()

matches = re.findall(pattern, file_contents, re.DOTALL)
for match in matches:
func_name = match.split("def ")[-1].split("(")[0]
out.append((fname.removesuffix(".py"), func_name))

return out


def check_fastmath():
"""
Check if all njit functions have the `fastmath` flag set
Parameters
----------
None
Returns
-------
None
"""
missing_fastmath = [] # list of njit functions with missing fastmath flags
for module_name, func_name in get_njit_funcs():
module = importlib.import_module(f".{module_name}", package="stumpy")
func = getattr(module, func_name)
if "fastmath" not in func.targetoptions.keys():
missing_fastmath.append(f"{module_name}.{func_name}")

if len(missing_fastmath) > 0:
msg = "Found one or more functions that are missing the `fastmath` flag. "
msg += f"The function(s) are:\n {missing_fastmath}\n"
raise ValueError(msg)

return


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check", action="store_true")
args = parser.parse_args()

if args.check:
check_fastmath()
15 changes: 10 additions & 5 deletions stumpy/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,16 @@ def get_njit_funcs():
if isinstance(node, ast.FunctionDef):
func_name = node.name
for decorator in node.decorator_list:
if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Name
):
if decorator.func.id == "njit":
njit_funcs.append((module_name, func_name))
scenario_1 = isinstance(decorator, ast.Name) and (
decorator.id == "njit"
)
scenario_2 = (
isinstance(decorator, ast.Call)
and isinstance(decorator.func, ast.Name)
and decorator.func.id == "njit"
)
if scenario_1 or scenario_2:
njit_funcs.append((module_name, func_name))

return njit_funcs

Expand Down
17 changes: 15 additions & 2 deletions test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,13 @@ check_print()
fi
}

check_fastmath()
{
echo "Checking Missing fastmath flags in njit functions"
./fastmath.py --check
check_errs $?
}

check_naive()
{
# Check if there are any naive implementations not at start of test file
Expand Down Expand Up @@ -146,14 +153,14 @@ set_ray_coveragerc()
show_coverage_report()
{
set_ray_coveragerc
coverage report -m --fail-under=100 --skip-covered --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
coverage report -m --fail-under=100 --skip-covered --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
}

gen_coverage_xml_report()
{
# This function saves the coverage report in Cobertura XML format, which is compatible with codecov
set_ray_coveragerc
coverage xml -o $fcoveragexml --fail-under=100 --omit=docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
coverage xml -o $fcoveragexml --fail-under=100 --omit=fastmath.py,docstring.py,min_versions.py,ray_python_version.py,stumpy/cache.py $fcoveragerc
}

test_custom()
Expand Down Expand Up @@ -333,6 +340,12 @@ check_print
check_naive
check_ray


if [[ -z $NUMBA_DISABLE_JIT || $NUMBA_DISABLE_JIT -eq 0 ]]; then
check_fastmath
fi


if [[ $test_mode == "notebooks" ]]; then
echo "Executing Tutorial Notebooks Only"
convert_notebooks
Expand Down
Loading