Skip to content

Commit 5fa5e1f

Browse files
committed
minor changes
1 parent d6740ed commit 5fa5e1f

File tree

1 file changed

+20
-18
lines changed

1 file changed

+20
-18
lines changed

fastmath.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ class FunctionCallVisitor(ast.NodeVisitor):
105105
Attributes
106106
----------
107107
module_names : list
108-
A list of module names to track the modules as the visitor traverses their AST
108+
A list of module names to track the modules as the visitor traverses their AST.
109109
110110
call_stack : list
111111
A list of njit functions, representing a chain of function calls,
@@ -289,7 +289,7 @@ def pop_call_stack(self):
289289
def goto_deeper_func(self, node):
290290
"""
291291
Calls the visit method from class `ast.NodeVisitor` on
292-
all children of the node.
292+
all children of the `node`.
293293
294294
Parameters
295295
----------
@@ -324,8 +324,9 @@ def goto_next_func(self, node):
324324

325325
def push_out(self):
326326
"""
327-
Push the current function call stack onto the output list if it is not
328-
included in one of the existing call stacks in `self.out`.
327+
Push the current function call stack onto the output list unless it
328+
is already included in one of the so-far-collected call stacks.
329+
329330
330331
Parameters
331332
----------
@@ -348,7 +349,7 @@ def push_out(self):
348349

349350
def visit_Call(self, node):
350351
"""
351-
Visit an AST node of type `ast.Call`.
352+
Called when visiting an AST node of type `ast.Call`.
352353
353354
Parameters
354355
----------
@@ -378,10 +379,10 @@ def visit_Call(self, node):
378379
callee_node = self.njit_nodes[callee_name]
379380
self.push_call_stack(new_module_name, new_func_name)
380381
self.goto_deeper_func(callee_node)
381-
if module_changed:
382-
self.pop_module()
383382
self.push_out()
384383
self.pop_call_stack()
384+
if module_changed:
385+
self.pop_module()
385386

386387
self.goto_next_func(node)
387388

@@ -390,9 +391,9 @@ def visit_Call(self, node):
390391

391392
def get_njit_call_stacks(pkg_dir, pkg_name):
392393
"""
393-
Get the call stacks of all njit functions in STUMPY.
394-
This function traverses the AST of each module in STUMPY and returns
395-
a list of unique function call stacks.
394+
Get the call stacks of all njit functions in `pkg_dir`.
395+
This function traverses the AST of each module in `pkg_dir`
396+
and returns a list of unique function call stacks.
396397
397398
Parameters
398399
----------
@@ -405,8 +406,8 @@ def get_njit_call_stacks(pkg_dir, pkg_name):
405406
Returns
406407
-------
407408
out : list
408-
A list of unique function call stacks. Each element is a list of strings,
409-
where each string represents a function call in the stack.
409+
A list of unique function call stacks. Each item is of type list,
410+
representing a chain of function calls.
410411
"""
411412
visitor = FunctionCallVisitor(pkg_dir, pkg_name)
412413

@@ -430,7 +431,7 @@ def check_call_stack_fastmath(pkg_dir, pkg_name):
430431
"""
431432
Check if all njit functions in a call stack have the same `fastmath` flag.
432433
This function raises a ValueError if it finds any inconsistencies in the
433-
`fastmath` flags in any call stack of njit functions.
434+
`fastmath` flags in at lease one call stack of njit functions.
434435
435436
Parameters
436437
----------
@@ -444,10 +445,10 @@ def check_call_stack_fastmath(pkg_dir, pkg_name):
444445
-------
445446
None
446447
"""
447-
out = get_njit_call_stacks(pkg_dir, pkg_name)
448-
449448
inconsitent_call_stacks = []
450-
for cs in out:
449+
450+
njit_call_stacks = get_njit_call_stacks(pkg_dir, pkg_name)
451+
for cs in njit_call_stacks:
451452
# Set the fastmath flag of the first function in the call stack
452453
# as the reference flag
453454
module_name, func_name = cs[0].split(".")
@@ -459,14 +460,15 @@ def check_call_stack_fastmath(pkg_dir, pkg_name):
459460
module_name, func_name = cs[0].split(".")
460461
module = importlib.import_module(f".{module_name}", package="stumpy")
461462
func = getattr(module, func_name)
462-
if func.targetoptions["fastmath"] != flag_ref:
463+
flag = func.targetoptions["fastmath"]
464+
if flag != flag_ref:
463465
inconsitent_call_stacks.append(cs)
464466
break
465467

466468
if len(inconsitent_call_stacks) > 0:
467469
msg = (
468470
"Found at least one callstack that have inconsistent `fastmath` flags. "
469-
+ f"The functions are:\n {inconsitent_call_stacks}\n"
471+
+ f"Those call stacks are:\n {inconsitent_call_stacks}\n"
470472
)
471473
raise ValueError(msg)
472474

0 commit comments

Comments
 (0)