Skip to content

Commit 69995ca

Browse files
committed
Add check for fastmath flags of callstacks
1 parent 9bdbae9 commit 69995ca

File tree

1 file changed

+379
-0
lines changed

1 file changed

+379
-0
lines changed

fastmath.py

Lines changed: 379 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,384 @@ def check_fastmath(pkg_dir, pkg_name):
8989
return
9090

9191

92+
class FunctionCallVisitor(ast.NodeVisitor):
93+
"""
94+
A class to traverse the AST of modules of a package to collect the call stacks
95+
of njit functions.
96+
97+
Parameters
98+
----------
99+
pkg_dir : str
100+
The path to the package directory containing some .py files.
101+
102+
pkg_name : str
103+
The name of the package.
104+
105+
Attributes
106+
----------
107+
module_names : list
108+
A list of module names to track the modules as the visitor traverses their AST
109+
110+
call_stack : list
111+
A list of function calls made in the current module
112+
113+
out : list
114+
A list of unique function call stacks.
115+
116+
njit_funcs : list
117+
A list of njit functions in STUMPY. Each element is a tuple of the form
118+
(module_name, func_name).
119+
120+
njit_modules : set
121+
A set of module names, where each contains at least one njit function.
122+
123+
njit_nodes : dict
124+
A dictionary mapping njit function names to their corresponding AST nodes.
125+
A key is of the form "module_name.func_name", and its corresponding value
126+
is the AST node- with type ast.FunctionDef- of that njit function
127+
128+
ast_modules : dict
129+
A dictionary mapping module names to their corresponding AST objects. A key
130+
is of the form "module_name", and its corresponding value is the content of
131+
the module as an AST object.
132+
133+
Methods
134+
-------
135+
push_module(module_name)
136+
Push a module name onto the stack of module names.
137+
138+
pop_module()
139+
Pop the last module name from the stack of module names.
140+
141+
push_call_stack(module_name, func_name)
142+
Push a function call onto the stack of function calls.
143+
144+
pop_call_stack()
145+
Pop the last function call from the stack of function calls.
146+
147+
goto_deeper_func(node)
148+
Calls the visit method from class `ast.NodeVisitor` on all children of the node.
149+
150+
goto_next_func(node)
151+
Calls the visit method from class `ast.NodeVisitor` on all children of the node.
152+
153+
push_out()
154+
Push the current function call stack onto the output list if it is not
155+
included in one of the existing call stacks in `self.out`.
156+
157+
visit_Call(node)
158+
Visit an AST node of type `ast.Call`. This method is called when the visitor
159+
encounters a function call in the AST. It checks if the called function is
160+
a njit function and, if so, traverses its AST to collect its call stack.
161+
"""
162+
163+
def __init__(self, pkg_dir, pkg_name):
164+
"""
165+
Initialize the FunctionCallVisitor class. This method sets up the necessary
166+
attributes and prepares the visitor for traversing the AST of STUMPY's modules.
167+
168+
Parameters
169+
----------
170+
pkg_dir : str
171+
The path to the package directory containing some .py files.
172+
173+
pkg_name : str
174+
The name of the package.
175+
176+
Returns
177+
-------
178+
None
179+
"""
180+
super().__init__()
181+
self.module_names = []
182+
self.call_stack = []
183+
self.out = []
184+
185+
# Setup lists, dicts, and ast objects
186+
self.njit_funcs = get_njit_funcs(pkg_dir)
187+
self.njit_modules = set(mod_name for mod_name, func_name in self.njit_funcs)
188+
self.njit_nodes = {}
189+
self.ast_modules = {}
190+
191+
filepaths = sorted(f for f in pathlib.Path(pkg_dir).iterdir() if f.is_file())
192+
ignore = ["__init__.py", "__pycache__"]
193+
194+
for filepath in filepaths:
195+
file_name = filepath.name
196+
if (
197+
file_name not in ignore
198+
and not file_name.startswith("gpu")
199+
and str(filepath).endswith(".py")
200+
):
201+
module_name = file_name.replace(".py", "")
202+
file_contents = ""
203+
with open(filepath, encoding="utf8") as f:
204+
file_contents = f.read()
205+
self.ast_modules[module_name] = ast.parse(file_contents)
206+
207+
for node in self.ast_modules[module_name].body:
208+
if isinstance(node, ast.FunctionDef):
209+
func_name = node.name
210+
if (module_name, func_name) in self.njit_funcs:
211+
self.njit_nodes[f"{module_name}.{func_name}"] = node
212+
213+
def push_module(self, module_name):
214+
"""
215+
Push a module name onto the stack of module names.
216+
217+
Parameters
218+
----------
219+
module_name : str
220+
The name of the module to be pushed onto the stack.
221+
222+
Returns
223+
-------
224+
None
225+
"""
226+
self.module_names.append(module_name)
227+
228+
return
229+
230+
def pop_module(self):
231+
"""
232+
Pop the last module name from the stack of module names.
233+
234+
Parameters
235+
----------
236+
None
237+
238+
Returns
239+
-------
240+
None
241+
"""
242+
if self.module_names:
243+
self.module_names.pop()
244+
245+
return
246+
247+
def push_call_stack(self, module_name, func_name):
248+
"""
249+
Push a function call onto the stack of function calls.
250+
251+
Parameters
252+
----------
253+
module_name : str
254+
The name of the module containing the function being called.
255+
256+
func_name : str
257+
The name of the function being called.
258+
259+
Returns
260+
-------
261+
None
262+
"""
263+
self.call_stack.append(f"{module_name}.{func_name}")
264+
265+
return
266+
267+
def pop_call_stack(self):
268+
"""
269+
Pop the last function call from the stack of function calls.
270+
271+
Parameters
272+
----------
273+
None
274+
275+
Returns
276+
-------
277+
None
278+
"""
279+
if self.call_stack:
280+
self.call_stack.pop()
281+
282+
return
283+
284+
def goto_deeper_func(self, node):
285+
"""
286+
Calls the visit method from class `ast.NodeVisitor` on
287+
all children of the node.
288+
289+
Parameters
290+
----------
291+
node : ast.AST
292+
The AST node to be visited.
293+
294+
Returns
295+
-------
296+
None
297+
"""
298+
self.generic_visit(node)
299+
300+
return
301+
302+
def goto_next_func(self, node):
303+
"""
304+
Calls the visit method from class `ast.NodeVisitor` on
305+
all children of the node.
306+
307+
Parameters
308+
----------
309+
node : ast.AST
310+
The AST node to be visited.
311+
312+
Returns
313+
-------
314+
None
315+
"""
316+
self.generic_visit(node)
317+
318+
return
319+
320+
def push_out(self):
321+
"""
322+
Push the current function call stack onto the output list if it is not
323+
included in one of the existing call stacks in `self.out`.
324+
325+
Parameters
326+
----------
327+
None
328+
329+
Returns
330+
-------
331+
None
332+
"""
333+
unique = True
334+
for cs in self.out:
335+
if " ".join(self.call_stack) in " ".join(cs):
336+
unique = False
337+
break
338+
339+
if unique:
340+
self.out.append(self.call_stack.copy())
341+
342+
return
343+
344+
def visit_Call(self, node):
345+
"""
346+
Visit an AST node of type `ast.Call`.
347+
348+
Parameters
349+
----------
350+
node : ast.Call
351+
The AST node representing a function call.
352+
353+
Returns
354+
-------
355+
None
356+
"""
357+
callee_name = ast.unparse(node.func)
358+
359+
module_changed = False
360+
if "." in callee_name:
361+
new_module_name, new_func_name = callee_name.split(".")[:2]
362+
363+
if new_module_name in self.njit_modules:
364+
self.push_module(new_module_name)
365+
module_changed = True
366+
else:
367+
if self.module_names:
368+
new_module_name = self.module_names[-1]
369+
new_func_name = callee_name
370+
callee_name = f"{new_module_name}.{new_func_name}"
371+
372+
if callee_name in self.njit_nodes.keys():
373+
callee_node = self.njit_nodes[callee_name]
374+
self.push_call_stack(new_module_name, new_func_name)
375+
self.goto_deeper_func(callee_node)
376+
if module_changed:
377+
self.pop_module()
378+
self.push_out()
379+
self.pop_call_stack()
380+
381+
self.goto_next_func(node)
382+
383+
return
384+
385+
386+
def get_njit_call_stacks(pkg_dir, pkg_name):
387+
"""
388+
Get the call stacks of all njit functions in STUMPY.
389+
This function traverses the AST of each module in STUMPY and returns
390+
a list of unique function call stacks.
391+
392+
Parameters
393+
----------
394+
pkg_dir : str
395+
The path to the package directory containing some .py files
396+
397+
pkg_name : str
398+
The name of the package
399+
400+
Returns
401+
-------
402+
out : list
403+
A list of unique function call stacks. Each element is a list of strings,
404+
where each string represents a function call in the stack.
405+
"""
406+
visitor = FunctionCallVisitor(pkg_dir, pkg_name)
407+
408+
for module_name in visitor.njit_modules:
409+
visitor.push_module(module_name)
410+
411+
for node in visitor.ast_modules[module_name].body:
412+
if isinstance(node, ast.FunctionDef):
413+
func_name = node.name
414+
if (module_name, func_name) in visitor.njit_funcs:
415+
visitor.push_call_stack(module_name, func_name)
416+
visitor.visit(node)
417+
visitor.pop_call_stack()
418+
419+
visitor.pop_module()
420+
421+
return visitor.out
422+
423+
424+
def check_fastmath_callstack(pkg_dir, pkg_name):
425+
"""
426+
Check if all njit functions in a callstack have the same `fastmath` flag.
427+
This function raises a ValueError if it finds any inconsistencies in the
428+
`fastmath` flags across the call stacks of njit functions.
429+
430+
Parameters
431+
----------
432+
pkg_dir : str
433+
The path to the package directory containing some .py files
434+
435+
pkg_name : str
436+
The name of the package
437+
438+
Returns
439+
-------
440+
None
441+
"""
442+
out = get_njit_call_stacks(pkg_dir, pkg_name)
443+
444+
fastmath_is_inconsistent = []
445+
for cs in out:
446+
module_name, func_name = cs[0].split(".")
447+
module = importlib.import_module(f".{module_name}", package="stumpy")
448+
func = getattr(module, func_name)
449+
flag = func.targetoptions["fastmath"]
450+
451+
for item in cs[1:]:
452+
module_name, func_name = cs[0].split(".")
453+
module = importlib.import_module(f".{module_name}", package="stumpy")
454+
func = getattr(module, func_name)
455+
func_flag = func.targetoptions["fastmath"]
456+
if func_flag != flag:
457+
fastmath_is_inconsistent.append(cs)
458+
break
459+
460+
if len(fastmath_is_inconsistent) > 0:
461+
msg = (
462+
"Found at least one callstack that have inconsistent `fastmath` flags. "
463+
+ f"The functions are:\n {fastmath_is_inconsistent}\n"
464+
)
465+
raise ValueError(msg)
466+
467+
return
468+
469+
92470
if __name__ == "__main__":
93471
parser = argparse.ArgumentParser()
94472
parser.add_argument("--check", dest="pkg_dir")
@@ -98,3 +476,4 @@ def check_fastmath(pkg_dir, pkg_name):
98476
pkg_dir = pathlib.Path(args.pkg_dir)
99477
pkg_name = pkg_dir.name
100478
check_fastmath(str(pkg_dir), pkg_name)
479+
check_fastmath_callstack(str(pkg_dir), pkg_name)

0 commit comments

Comments
 (0)