@@ -89,6 +89,384 @@ def check_fastmath(pkg_dir, pkg_name):
8989 return
9090
9191
92+ class FunctionCallVisitor (ast .NodeVisitor ):
93+ """
94+ A class to traverse the AST of modules of a package to collect the call stacks
95+ of njit functions.
96+
97+ Parameters
98+ ----------
99+ pkg_dir : str
100+ The path to the package directory containing some .py files.
101+
102+ pkg_name : str
103+ The name of the package.
104+
105+ Attributes
106+ ----------
107+ module_names : list
108+ A list of module names to track the modules as the visitor traverses their AST
109+
110+ call_stack : list
111+ A list of function calls made in the current module
112+
113+ out : list
114+ A list of unique function call stacks.
115+
116+ njit_funcs : list
117+ A list of njit functions in STUMPY. Each element is a tuple of the form
118+ (module_name, func_name).
119+
120+ njit_modules : set
121+ A set of module names, where each contains at least one njit function.
122+
123+ njit_nodes : dict
124+ A dictionary mapping njit function names to their corresponding AST nodes.
125+ A key is of the form "module_name.func_name", and its corresponding value
126+ is the AST node- with type ast.FunctionDef- of that njit function
127+
128+ ast_modules : dict
129+ A dictionary mapping module names to their corresponding AST objects. A key
130+ is of the form "module_name", and its corresponding value is the content of
131+ the module as an AST object.
132+
133+ Methods
134+ -------
135+ push_module(module_name)
136+ Push a module name onto the stack of module names.
137+
138+ pop_module()
139+ Pop the last module name from the stack of module names.
140+
141+ push_call_stack(module_name, func_name)
142+ Push a function call onto the stack of function calls.
143+
144+ pop_call_stack()
145+ Pop the last function call from the stack of function calls.
146+
147+ goto_deeper_func(node)
148+ Calls the visit method from class `ast.NodeVisitor` on all children of the node.
149+
150+ goto_next_func(node)
151+ Calls the visit method from class `ast.NodeVisitor` on all children of the node.
152+
153+ push_out()
154+ Push the current function call stack onto the output list if it is not
155+ included in one of the existing call stacks in `self.out`.
156+
157+ visit_Call(node)
158+ Visit an AST node of type `ast.Call`. This method is called when the visitor
159+ encounters a function call in the AST. It checks if the called function is
160+ a njit function and, if so, traverses its AST to collect its call stack.
161+ """
162+
163+ def __init__ (self , pkg_dir , pkg_name ):
164+ """
165+ Initialize the FunctionCallVisitor class. This method sets up the necessary
166+ attributes and prepares the visitor for traversing the AST of STUMPY's modules.
167+
168+ Parameters
169+ ----------
170+ pkg_dir : str
171+ The path to the package directory containing some .py files.
172+
173+ pkg_name : str
174+ The name of the package.
175+
176+ Returns
177+ -------
178+ None
179+ """
180+ super ().__init__ ()
181+ self .module_names = []
182+ self .call_stack = []
183+ self .out = []
184+
185+ # Setup lists, dicts, and ast objects
186+ self .njit_funcs = get_njit_funcs (pkg_dir )
187+ self .njit_modules = set (mod_name for mod_name , func_name in self .njit_funcs )
188+ self .njit_nodes = {}
189+ self .ast_modules = {}
190+
191+ filepaths = sorted (f for f in pathlib .Path (pkg_dir ).iterdir () if f .is_file ())
192+ ignore = ["__init__.py" , "__pycache__" ]
193+
194+ for filepath in filepaths :
195+ file_name = filepath .name
196+ if (
197+ file_name not in ignore
198+ and not file_name .startswith ("gpu" )
199+ and str (filepath ).endswith (".py" )
200+ ):
201+ module_name = file_name .replace (".py" , "" )
202+ file_contents = ""
203+ with open (filepath , encoding = "utf8" ) as f :
204+ file_contents = f .read ()
205+ self .ast_modules [module_name ] = ast .parse (file_contents )
206+
207+ for node in self .ast_modules [module_name ].body :
208+ if isinstance (node , ast .FunctionDef ):
209+ func_name = node .name
210+ if (module_name , func_name ) in self .njit_funcs :
211+ self .njit_nodes [f"{ module_name } .{ func_name } " ] = node
212+
213+ def push_module (self , module_name ):
214+ """
215+ Push a module name onto the stack of module names.
216+
217+ Parameters
218+ ----------
219+ module_name : str
220+ The name of the module to be pushed onto the stack.
221+
222+ Returns
223+ -------
224+ None
225+ """
226+ self .module_names .append (module_name )
227+
228+ return
229+
230+ def pop_module (self ):
231+ """
232+ Pop the last module name from the stack of module names.
233+
234+ Parameters
235+ ----------
236+ None
237+
238+ Returns
239+ -------
240+ None
241+ """
242+ if self .module_names :
243+ self .module_names .pop ()
244+
245+ return
246+
247+ def push_call_stack (self , module_name , func_name ):
248+ """
249+ Push a function call onto the stack of function calls.
250+
251+ Parameters
252+ ----------
253+ module_name : str
254+ The name of the module containing the function being called.
255+
256+ func_name : str
257+ The name of the function being called.
258+
259+ Returns
260+ -------
261+ None
262+ """
263+ self .call_stack .append (f"{ module_name } .{ func_name } " )
264+
265+ return
266+
267+ def pop_call_stack (self ):
268+ """
269+ Pop the last function call from the stack of function calls.
270+
271+ Parameters
272+ ----------
273+ None
274+
275+ Returns
276+ -------
277+ None
278+ """
279+ if self .call_stack :
280+ self .call_stack .pop ()
281+
282+ return
283+
284+ def goto_deeper_func (self , node ):
285+ """
286+ Calls the visit method from class `ast.NodeVisitor` on
287+ all children of the node.
288+
289+ Parameters
290+ ----------
291+ node : ast.AST
292+ The AST node to be visited.
293+
294+ Returns
295+ -------
296+ None
297+ """
298+ self .generic_visit (node )
299+
300+ return
301+
302+ def goto_next_func (self , node ):
303+ """
304+ Calls the visit method from class `ast.NodeVisitor` on
305+ all children of the node.
306+
307+ Parameters
308+ ----------
309+ node : ast.AST
310+ The AST node to be visited.
311+
312+ Returns
313+ -------
314+ None
315+ """
316+ self .generic_visit (node )
317+
318+ return
319+
320+ def push_out (self ):
321+ """
322+ Push the current function call stack onto the output list if it is not
323+ included in one of the existing call stacks in `self.out`.
324+
325+ Parameters
326+ ----------
327+ None
328+
329+ Returns
330+ -------
331+ None
332+ """
333+ unique = True
334+ for cs in self .out :
335+ if " " .join (self .call_stack ) in " " .join (cs ):
336+ unique = False
337+ break
338+
339+ if unique :
340+ self .out .append (self .call_stack .copy ())
341+
342+ return
343+
344+ def visit_Call (self , node ):
345+ """
346+ Visit an AST node of type `ast.Call`.
347+
348+ Parameters
349+ ----------
350+ node : ast.Call
351+ The AST node representing a function call.
352+
353+ Returns
354+ -------
355+ None
356+ """
357+ callee_name = ast .unparse (node .func )
358+
359+ module_changed = False
360+ if "." in callee_name :
361+ new_module_name , new_func_name = callee_name .split ("." )[:2 ]
362+
363+ if new_module_name in self .njit_modules :
364+ self .push_module (new_module_name )
365+ module_changed = True
366+ else :
367+ if self .module_names :
368+ new_module_name = self .module_names [- 1 ]
369+ new_func_name = callee_name
370+ callee_name = f"{ new_module_name } .{ new_func_name } "
371+
372+ if callee_name in self .njit_nodes .keys ():
373+ callee_node = self .njit_nodes [callee_name ]
374+ self .push_call_stack (new_module_name , new_func_name )
375+ self .goto_deeper_func (callee_node )
376+ if module_changed :
377+ self .pop_module ()
378+ self .push_out ()
379+ self .pop_call_stack ()
380+
381+ self .goto_next_func (node )
382+
383+ return
384+
385+
386+ def get_njit_call_stacks (pkg_dir , pkg_name ):
387+ """
388+ Get the call stacks of all njit functions in STUMPY.
389+ This function traverses the AST of each module in STUMPY and returns
390+ a list of unique function call stacks.
391+
392+ Parameters
393+ ----------
394+ pkg_dir : str
395+ The path to the package directory containing some .py files
396+
397+ pkg_name : str
398+ The name of the package
399+
400+ Returns
401+ -------
402+ out : list
403+ A list of unique function call stacks. Each element is a list of strings,
404+ where each string represents a function call in the stack.
405+ """
406+ visitor = FunctionCallVisitor (pkg_dir , pkg_name )
407+
408+ for module_name in visitor .njit_modules :
409+ visitor .push_module (module_name )
410+
411+ for node in visitor .ast_modules [module_name ].body :
412+ if isinstance (node , ast .FunctionDef ):
413+ func_name = node .name
414+ if (module_name , func_name ) in visitor .njit_funcs :
415+ visitor .push_call_stack (module_name , func_name )
416+ visitor .visit (node )
417+ visitor .pop_call_stack ()
418+
419+ visitor .pop_module ()
420+
421+ return visitor .out
422+
423+
424+ def check_fastmath_callstack (pkg_dir , pkg_name ):
425+ """
426+ Check if all njit functions in a callstack have the same `fastmath` flag.
427+ This function raises a ValueError if it finds any inconsistencies in the
428+ `fastmath` flags across the call stacks of njit functions.
429+
430+ Parameters
431+ ----------
432+ pkg_dir : str
433+ The path to the package directory containing some .py files
434+
435+ pkg_name : str
436+ The name of the package
437+
438+ Returns
439+ -------
440+ None
441+ """
442+ out = get_njit_call_stacks (pkg_dir , pkg_name )
443+
444+ fastmath_is_inconsistent = []
445+ for cs in out :
446+ module_name , func_name = cs [0 ].split ("." )
447+ module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
448+ func = getattr (module , func_name )
449+ flag = func .targetoptions ["fastmath" ]
450+
451+ for item in cs [1 :]:
452+ module_name , func_name = cs [0 ].split ("." )
453+ module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
454+ func = getattr (module , func_name )
455+ func_flag = func .targetoptions ["fastmath" ]
456+ if func_flag != flag :
457+ fastmath_is_inconsistent .append (cs )
458+ break
459+
460+ if len (fastmath_is_inconsistent ) > 0 :
461+ msg = (
462+ "Found at least one callstack that have inconsistent `fastmath` flags. "
463+ + f"The functions are:\n { fastmath_is_inconsistent } \n "
464+ )
465+ raise ValueError (msg )
466+
467+ return
468+
469+
92470if __name__ == "__main__" :
93471 parser = argparse .ArgumentParser ()
94472 parser .add_argument ("--check" , dest = "pkg_dir" )
@@ -98,3 +476,4 @@ def check_fastmath(pkg_dir, pkg_name):
98476 pkg_dir = pathlib .Path (args .pkg_dir )
99477 pkg_name = pkg_dir .name
100478 check_fastmath (str (pkg_dir ), pkg_name )
479+ check_fastmath_callstack (str (pkg_dir ), pkg_name )
0 commit comments