@@ -105,7 +105,7 @@ class FunctionCallVisitor(ast.NodeVisitor):
105105 Attributes
106106 ----------
107107 module_names : list
108- A list of module names to track the modules as the visitor traverses their AST
108+ A list of module names to track the modules as the visitor traverses their AST.
109109
110110 call_stack : list
111111 A list of njit functions, representing a chain of function calls,
@@ -289,7 +289,7 @@ def pop_call_stack(self):
289289 def goto_deeper_func (self , node ):
290290 """
291291 Calls the visit method from class `ast.NodeVisitor` on
292- all children of the node.
292+ all children of the ` node` .
293293
294294 Parameters
295295 ----------
@@ -324,8 +324,9 @@ def goto_next_func(self, node):
324324
325325 def push_out (self ):
326326 """
327- Push the current function call stack onto the output list if it is not
328- included in one of the existing call stacks in `self.out`.
327+ Push the current function call stack onto the output list unless it
328+ is already included in one of the so-far-collected call stacks.
329+
329330
330331 Parameters
331332 ----------
@@ -348,7 +349,7 @@ def push_out(self):
348349
349350 def visit_Call (self , node ):
350351 """
351- Visit an AST node of type `ast.Call`.
352+ Called when visiting an AST node of type `ast.Call`.
352353
353354 Parameters
354355 ----------
@@ -378,10 +379,10 @@ def visit_Call(self, node):
378379 callee_node = self .njit_nodes [callee_name ]
379380 self .push_call_stack (new_module_name , new_func_name )
380381 self .goto_deeper_func (callee_node )
381- if module_changed :
382- self .pop_module ()
383382 self .push_out ()
384383 self .pop_call_stack ()
384+ if module_changed :
385+ self .pop_module ()
385386
386387 self .goto_next_func (node )
387388
@@ -390,9 +391,9 @@ def visit_Call(self, node):
390391
391392def get_njit_call_stacks (pkg_dir , pkg_name ):
392393 """
393- Get the call stacks of all njit functions in STUMPY .
394- This function traverses the AST of each module in STUMPY and returns
395- a list of unique function call stacks.
394+ Get the call stacks of all njit functions in `pkg_dir` .
395+ This function traverses the AST of each module in `pkg_dir`
396+ and returns a list of unique function call stacks.
396397
397398 Parameters
398399 ----------
@@ -405,8 +406,8 @@ def get_njit_call_stacks(pkg_dir, pkg_name):
405406 Returns
406407 -------
407408 out : list
408- A list of unique function call stacks. Each element is a list of strings ,
409- where each string represents a function call in the stack .
409+ A list of unique function call stacks. Each item is of type list ,
410+ representing a chain of function calls .
410411 """
411412 visitor = FunctionCallVisitor (pkg_dir , pkg_name )
412413
@@ -430,7 +431,7 @@ def check_call_stack_fastmath(pkg_dir, pkg_name):
430431 """
431432 Check if all njit functions in a call stack have the same `fastmath` flag.
432433 This function raises a ValueError if it finds any inconsistencies in the
433- `fastmath` flags in any call stack of njit functions.
434+ `fastmath` flags in at lease one call stack of njit functions.
434435
435436 Parameters
436437 ----------
@@ -444,10 +445,10 @@ def check_call_stack_fastmath(pkg_dir, pkg_name):
444445 -------
445446 None
446447 """
447- out = get_njit_call_stacks (pkg_dir , pkg_name )
448-
449448 inconsitent_call_stacks = []
450- for cs in out :
449+
450+ njit_call_stacks = get_njit_call_stacks (pkg_dir , pkg_name )
451+ for cs in njit_call_stacks :
451452 # Set the fastmath flag of the first function in the call stack
452453 # as the reference flag
453454 module_name , func_name = cs [0 ].split ("." )
@@ -459,14 +460,15 @@ def check_call_stack_fastmath(pkg_dir, pkg_name):
459460 module_name , func_name = cs [0 ].split ("." )
460461 module = importlib .import_module (f".{ module_name } " , package = "stumpy" )
461462 func = getattr (module , func_name )
462- if func .targetoptions ["fastmath" ] != flag_ref :
463+ flag = func .targetoptions ["fastmath" ]
464+ if flag != flag_ref :
463465 inconsitent_call_stacks .append (cs )
464466 break
465467
466468 if len (inconsitent_call_stacks ) > 0 :
467469 msg = (
468470 "Found at least one callstack that have inconsistent `fastmath` flags. "
469- + f"The functions are:\n { inconsitent_call_stacks } \n "
471+ + f"Those call stacks are:\n { inconsitent_call_stacks } \n "
470472 )
471473 raise ValueError (msg )
472474
0 commit comments