Skip to content

Commit 08b46c6

Browse files
committed
minor changes
1 parent 69995ca commit 08b46c6

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

fastmath.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class FunctionCallVisitor(ast.NodeVisitor):
152152
153153
push_out()
154154
Push the current function call stack onto the output list if it is not
155-
included in one of the existing call stacks in `self.out`.
155+
included in one of the so-far-collected call stacks.
156156
157157
visit_Call(node)
158158
Visit an AST node of type `ast.Call`. This method is called when the visitor
@@ -421,16 +421,16 @@ def get_njit_call_stacks(pkg_dir, pkg_name):
421421
return visitor.out
422422

423423

424-
def check_fastmath_callstack(pkg_dir, pkg_name):
424+
def check_call_stack_fastmath(pkg_dir, pkg_name):
425425
"""
426-
Check if all njit functions in a callstack have the same `fastmath` flag.
426+
Check if all njit functions in a call stack have the same `fastmath` flag.
427427
This function raises a ValueError if it finds any inconsistencies in the
428-
`fastmath` flags across the call stacks of njit functions.
428+
`fastmath` flags in any call stack of njit functions.
429429
430430
Parameters
431431
----------
432432
pkg_dir : str
433-
The path to the package directory containing some .py files
433+
The path to the directory containing some .py files
434434
435435
pkg_name : str
436436
The name of the package
@@ -441,26 +441,27 @@ def check_fastmath_callstack(pkg_dir, pkg_name):
441441
"""
442442
out = get_njit_call_stacks(pkg_dir, pkg_name)
443443

444-
fastmath_is_inconsistent = []
444+
inconsitent_call_stacks = []
445445
for cs in out:
446+
# Set the fastmath flag of the first function in the call stack
447+
# as the reference flag
446448
module_name, func_name = cs[0].split(".")
447449
module = importlib.import_module(f".{module_name}", package="stumpy")
448450
func = getattr(module, func_name)
449-
flag = func.targetoptions["fastmath"]
451+
flag_ref = func.targetoptions["fastmath"]
450452

451453
for item in cs[1:]:
452454
module_name, func_name = cs[0].split(".")
453455
module = importlib.import_module(f".{module_name}", package="stumpy")
454456
func = getattr(module, func_name)
455-
func_flag = func.targetoptions["fastmath"]
456-
if func_flag != flag:
457-
fastmath_is_inconsistent.append(cs)
457+
if func.targetoptions["fastmath"] != flag_ref:
458+
inconsitent_call_stacks.append(cs)
458459
break
459460

460-
if len(fastmath_is_inconsistent) > 0:
461+
if len(inconsitent_call_stacks) > 0:
461462
msg = (
462463
"Found at least one callstack that have inconsistent `fastmath` flags. "
463-
+ f"The functions are:\n {fastmath_is_inconsistent}\n"
464+
+ f"The functions are:\n {inconsitent_call_stacks}\n"
464465
)
465466
raise ValueError(msg)
466467

@@ -476,4 +477,4 @@ def check_fastmath_callstack(pkg_dir, pkg_name):
476477
pkg_dir = pathlib.Path(args.pkg_dir)
477478
pkg_name = pkg_dir.name
478479
check_fastmath(str(pkg_dir), pkg_name)
479-
check_fastmath_callstack(str(pkg_dir), pkg_name)
480+
check_call_stack_fastmath(str(pkg_dir), pkg_name)

0 commit comments

Comments
 (0)