66import pathlib
77
88
9- def get_njit_funcs (pkg_dir ):
9+ def get_func_nodes (pkg_dir ):
1010 """
11- Identify all njit functions
11+ Retrun a dictionary where the keys are the module names and the values are
12+ the function AST nodes
1213
1314 Parameters
1415 ----------
1516 pkg_dir : str
16- The path to the directory containing some .py files
17+ The path to the directory containing some .py files
1718
1819 Returns
1920 -------
20- njit_funcs : list
21- A list of all njit functions, where each element is a tuple of the form
22- (module_name, func_name)
21+ out : dict
22+ A dictionary where the keys are the module names and the values are a list of
23+ AST nodes for each njit function in the module
2324 """
2425 ignore_py_files = ["__init__" , "__pycache__" ]
2526 pkg_dir = pathlib .Path (pkg_dir )
@@ -29,29 +30,56 @@ def get_njit_funcs(pkg_dir):
2930 if fname .stem not in ignore_py_files and not fname .stem .startswith ("." ):
3031 module_names .append (fname .stem )
3132
32- njit_funcs = []
33+ out = {}
3334 for module_name in module_names :
3435 filepath = pkg_dir / f"{ module_name } .py"
3536 file_contents = ""
3637 with open (filepath , encoding = "utf8" ) as f :
3738 file_contents = f .read ()
3839 module = ast .parse (file_contents )
40+
41+ module_funcs_nodes = []
3942 for node in module .body :
4043 if isinstance (node , ast .FunctionDef ):
41- func_name = node .name
42- for decorator in node .decorator_list :
43- decorator_name = None
44- if isinstance (decorator , ast .Name ):
45- # Bare decorator
46- decorator_name = decorator .id
47- if isinstance (decorator , ast .Call ) and isinstance (
48- decorator .func , ast .Name
49- ):
50- # Decorator is a function
51- decorator_name = decorator .func .id
52-
53- if decorator_name == "njit" :
54- njit_funcs .append ((module_name , func_name ))
44+ module_funcs_nodes .append (node )
45+ out [module_name ] = module_funcs_nodes
46+
47+ return out
48+
49+
50+ def get_njit_funcs (pkg_dir ):
51+ """
52+ Identify all njit functions
53+
54+ Parameters
55+ ----------
56+ pkg_dir : str
57+ The path to the directory containing some .py files
58+
59+ Returns
60+ -------
61+ njit_funcs : list
62+ A list of all njit functions, where each element is a tuple of the form
63+ (module_name, func_name)
64+ """
65+ njit_funcs = []
66+ modules_funcs_nodes = get_func_nodes (pkg_dir )
67+ for module_name , func_nodes in modules_funcs_nodes .items ():
68+ for node in func_nodes :
69+ func_name = node .name
70+ for decorator in node .decorator_list :
71+ decorator_name = None
72+ if isinstance (decorator , ast .Name ):
73+ # Bare decorator
74+ decorator_name = decorator .id
75+ if isinstance (decorator , ast .Call ) and isinstance (
76+ decorator .func , ast .Name
77+ ):
78+ # Decorator is a function
79+ decorator_name = decorator .func .id
80+
81+ if decorator_name == "njit" :
82+ njit_funcs .append ((module_name , func_name ))
5583
5684 return njit_funcs
5785
@@ -89,6 +117,60 @@ def check_fastmath(pkg_dir, pkg_name):
89117 return
90118
91119
120+ def check_hardcoded_flag (pkg_dir , pkg_name ):
121+ """
122+ Check if all `fastmath` flags are set to a config variable
123+
124+ Parameters
125+ ----------
126+ pkg_dir : str
127+ The path to the directory containing some .py files
128+
129+ pkg_name : str
130+ The name of the package
131+
132+ Returns
133+ -------
134+ None
135+ """
136+ ignore = [("fastmath" , "_add_assoc" )]
137+
138+ hardcoded_fastmath = [] # list of njit functions with hardcoded fastmath flags
139+ modules_funcs_nodes = get_func_nodes (pkg_dir )
140+ for module_name , func_nodes in modules_funcs_nodes .items ():
141+ for node in func_nodes :
142+ if (module_name , node .name ) in ignore :
143+ continue
144+
145+ njit_decorator_func = None
146+ for decorator in node .decorator_list :
147+ if (
148+ isinstance (decorator , ast .Call )
149+ and isinstance (decorator .func , ast .Name )
150+ and decorator .func .id == "njit"
151+ ):
152+ njit_decorator_func = decorator
153+ break
154+
155+ if njit_decorator_func is None :
156+ continue
157+
158+ for kwrd in njit_decorator_func .keywords :
159+ if kwrd .arg == "fastmath" :
160+ value = kwrd .value .value
161+ if not hasattr (value , "id" ) or value .id != "config" :
162+ hardcoded_fastmath .append (f"{ module_name } .{ node .name } " )
163+
164+ if len (hardcoded_fastmath ) > 0 :
165+ msg = (
166+ "Found one or more `@njit()` functions with hardcoded `fastmath` flag. "
167+ + f"The functions are:\n { hardcoded_fastmath } \n "
168+ )
169+ raise ValueError (msg )
170+
171+ return
172+
173+
92174if __name__ == "__main__" :
93175 parser = argparse .ArgumentParser ()
94176 parser .add_argument ("--check" , dest = "pkg_dir" )
@@ -98,3 +180,4 @@ def check_fastmath(pkg_dir, pkg_name):
98180 pkg_dir = pathlib .Path (args .pkg_dir )
99181 pkg_name = pkg_dir .name
100182 check_fastmath (str (pkg_dir ), pkg_name )
183+ check_hardcoded_flag (str (pkg_dir ), pkg_name )
0 commit comments