class FunctionCallVisitor(ast.NodeVisitor):
    """
    A class to traverse the AST of the modules of a package to collect
    the call stacks of njit functions.

    Parameters
    ----------
    pkg_dir : str
        The path to the package directory containing some .py files.

    pkg_name : str
        The name of the package.

    Attributes
    ----------
    module_names : list
        A list of module names to track the modules as the visitor traverses them.

    call_stack : list
        A list of njit functions, representing a chain of function calls,
        where each element is a string of the form "module_name.func_name".

    out : list
        A list of unique `call_stack`s.

    njit_funcs : list
        A list of all njit functions in `pkg_dir`'s modules. Each element is a tuple
        of the form `(module_name, func_name)`.

    njit_modules : set
        A set that contains the names of all modules, each of which contains at least
        one njit function.

    njit_nodes : dict
        A dictionary mapping njit function names to their corresponding AST nodes.
        A key is a string of the form "module_name.func_name", and its
        corresponding value is the AST node (of type `ast.FunctionDef`) of that
        function.

    ast_modules : dict
        A dictionary mapping module names to their corresponding AST objects. A key
        is the name of a module, and its corresponding value is the content of that
        module as an AST object.

    Methods
    -------
    push_module(module_name)
        Push the name of a module onto the stack `module_names`.

    pop_module()
        Pop the last module name from the stack `module_names`.

    push_call_stack(module_name, func_name)
        Push a function call onto the stack of function calls, `call_stack`.

    pop_call_stack()
        Pop the last function call from the stack of function calls, `call_stack`.

    goto_deeper_func(node)
        Calls the visit method from class `ast.NodeVisitor` on all children of
        the `node`.

    goto_next_func(node)
        Calls the visit method from class `ast.NodeVisitor` on all children of
        the `node`.

    push_out()
        Push the current function call stack, `call_stack`, onto the output list,
        `out`, unless it is already included in one of the so-far-collected call
        stacks.

    visit_Call(node)
        This method is called when the visitor encounters a function call in the
        AST. It checks if the called function is a njit function and, if so,
        traverses its AST to collect its call stack.
    """

    def __init__(self, pkg_dir, pkg_name):
        """
        Initialize the FunctionCallVisitor class. This method sets up the necessary
        attributes and parses every eligible module found in `pkg_dir`.

        Parameters
        ----------
        pkg_dir : str
            The path to the package directory containing some .py files.

        pkg_name : str
            The name of the package. It is accepted for interface symmetry and
            is not used by the visitor itself.

        Returns
        -------
        None
        """
        super().__init__()
        self.module_names = []
        self.call_stack = []
        self.out = []

        # Setup lists, dicts, and ast objects
        self.njit_funcs = get_njit_funcs(pkg_dir)
        self.njit_modules = {mod_name for mod_name, _ in self.njit_funcs}
        self.njit_nodes = {}
        self.ast_modules = {}

        # Sorted so that modules are always parsed in the same order
        filepaths = sorted(f for f in pathlib.Path(pkg_dir).iterdir() if f.is_file())
        ignore = ["__init__.py", "__pycache__"]

        for filepath in filepaths:
            file_name = filepath.name
            if (
                file_name in ignore
                or file_name.startswith("gpu")
                or not str(filepath).endswith(".py")
            ):
                continue

            module_name = file_name.replace(".py", "")
            with open(filepath, encoding="utf8") as f:
                module_ast = ast.parse(f.read())
            self.ast_modules[module_name] = module_ast

            # Only top-level function definitions are registered
            for node in module_ast.body:
                if (
                    isinstance(node, ast.FunctionDef)
                    and (module_name, node.name) in self.njit_funcs
                ):
                    self.njit_nodes[f"{module_name}.{node.name}"] = node

    def push_module(self, module_name):
        """
        Push a module name onto the stack of module names.

        Parameters
        ----------
        module_name : str
            The name of the module to be pushed onto the stack.

        Returns
        -------
        None
        """
        self.module_names.append(module_name)

        return

    def pop_module(self):
        """
        Pop the last module name from the stack of module names. This is a
        no-op when the stack is empty.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        if self.module_names:
            self.module_names.pop()

        return

    def push_call_stack(self, module_name, func_name):
        """
        Push a function call onto the stack of function calls.

        Parameters
        ----------
        module_name : str
            A module's name

        func_name : str
            A function's name

        Returns
        -------
        None
        """
        self.call_stack.append(f"{module_name}.{func_name}")

        return

    def pop_call_stack(self):
        """
        Pop the last function call from the stack of function calls. This is a
        no-op when the stack is empty.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        if self.call_stack:
            self.call_stack.pop()

        return

    def goto_deeper_func(self, node):
        """
        Calls the visit method from class `ast.NodeVisitor` on
        all children of the `node`.

        Parameters
        ----------
        node : ast.AST
            The AST node to be visited.

        Returns
        -------
        None
        """
        self.generic_visit(node)

        return

    def goto_next_func(self, node):
        """
        Calls the visit method from class `ast.NodeVisitor` on
        all children of the node.

        Parameters
        ----------
        node : ast.AST
            The AST node to be visited.

        Returns
        -------
        None
        """
        self.generic_visit(node)

        return

    def push_out(self):
        """
        Push a copy of the current function call stack onto the output list
        unless it already appears, as a contiguous run of calls, inside one of
        the so-far-collected call stacks.

        Parameters
        ----------
        None

        Returns
        -------
        None
        """
        stack = self.call_stack
        n = len(stack)

        # Compare whole elements rather than a joined string so that a name
        # that is merely a prefix/suffix of another name (e.g. "core.foo" vs
        # "core.foobar") is not mistaken for a match.
        for cs in self.out:
            if any(cs[i : i + n] == stack for i in range(len(cs) - n + 1)):
                return

        self.out.append(stack.copy())

        return

    def visit_Call(self, node):
        """
        Called when visiting an AST node of type `ast.Call`.

        If the callee resolves to a known njit function, the callee is pushed
        onto `call_stack`, its body is traversed, and the resulting call stack
        is recorded via `push_out`. The children of `node` are always visited
        afterwards.

        Parameters
        ----------
        node : ast.Call
            The AST node representing a function call.

        Returns
        -------
        None
        """
        callee_name = ast.unparse(node.func)

        module_changed = False
        new_module_name = None
        new_func_name = None
        if "." in callee_name:
            # Qualified call, e.g. `module.func(...)`
            new_module_name, new_func_name = callee_name.split(".")[:2]
            if new_module_name in self.njit_modules:
                self.push_module(new_module_name)
                module_changed = True
        elif self.module_names:
            # Bare call: resolve it against the module currently being visited
            new_module_name = self.module_names[-1]
            new_func_name = callee_name
            callee_name = f"{new_module_name}.{new_func_name}"

        callee_node = self.njit_nodes.get(callee_name)
        # A callee that is already on the call stack means (mutual) recursion;
        # descending again would never terminate.
        if callee_node is not None and callee_name not in self.call_stack:
            self.push_call_stack(new_module_name, new_func_name)
            self.goto_deeper_func(callee_node)
            self.push_out()
            self.pop_call_stack()

        # Always undo the module push, whether or not the callee was njit, so
        # that the arguments of this call are resolved in the caller's module.
        if module_changed:
            self.pop_module()

        self.goto_next_func(node)

        return


def get_njit_call_stacks(pkg_dir, pkg_name):
    """
    Get the call stacks of all njit functions in `pkg_dir`

    Parameters
    ----------
    pkg_dir : str
        The path to the package directory containing some .py files

    pkg_name : str
        The name of the package

    Returns
    -------
    out : list
        A list of unique function call stacks. Each item is of type list,
        representing a chain of function calls.
    """
    visitor = FunctionCallVisitor(pkg_dir, pkg_name)

    # `njit_modules` is a set; iterate in sorted order so that the collected
    # call stacks do not depend on hash ordering.
    for module_name in sorted(visitor.njit_modules):
        visitor.push_module(module_name)

        for node in visitor.ast_modules[module_name].body:
            if (
                isinstance(node, ast.FunctionDef)
                and (module_name, node.name) in visitor.njit_funcs
            ):
                visitor.push_call_stack(module_name, node.name)
                visitor.visit(node)
                visitor.pop_call_stack()

        visitor.pop_module()

    return visitor.out


def _get_fastmath_flag(pkg_name, qualified_name):
    """
    Return the `fastmath` target option of the function `qualified_name`
    (of the form "module_name.func_name") imported from package `pkg_name`.
    """
    module_name, func_name = qualified_name.split(".")
    module = importlib.import_module(f".{module_name}", package=pkg_name)
    func = getattr(module, func_name)

    return func.targetoptions["fastmath"]


def _find_inconsistent_call_stacks(call_stacks, pkg_name):
    """
    Return the call stacks in which at least one function has a `fastmath`
    flag that differs from the flag of the first function in that call stack.
    """
    inconsistent_call_stacks = []
    for cs in call_stacks:
        if not cs:
            continue

        # The fastmath flag of the first function in the call stack is the
        # reference flag that every other function is compared against
        flag_ref = _get_fastmath_flag(pkg_name, cs[0])
        if any(_get_fastmath_flag(pkg_name, item) != flag_ref for item in cs[1:]):
            inconsistent_call_stacks.append(cs)

    return inconsistent_call_stacks


def check_call_stack_fastmath(pkg_dir, pkg_name):
    """
    Check if all njit functions in a call stack have the same `fastmath` flag.
    This function raises a ValueError if it finds any inconsistencies in the
    `fastmath` flags in at least one call stack of njit functions.

    Parameters
    ----------
    pkg_dir : str
        The path to the directory containing some .py files

    pkg_name : str
        The name of the package. The modules named in the call stacks are
        imported relative to this package.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If at least one call stack has inconsistent `fastmath` flags.
    """
    njit_call_stacks = get_njit_call_stacks(pkg_dir, pkg_name)

    # List of call stacks with inconsistent fastmath flags
    inconsistent_call_stacks = _find_inconsistent_call_stacks(
        njit_call_stacks, pkg_name
    )

    if len(inconsistent_call_stacks) > 0:
        msg = (
            "Found at least one call stack that has inconsistent `fastmath` flags. "
            + f"Those call stacks are:\n {inconsistent_call_stacks}\n"
        )
        raise ValueError(msg)

    return
""" m = Q.shape[0] l = T.shape[0] - m + 1 @@ -1979,7 +1983,7 @@ def _get_QT(start, T_A, T_B, m): @njit( # ["(f8[:], i8, i8)", "(f8[:, :], i8, i8)"], - fastmath=config.STUMPY_FASTMATH_TRUE + fastmath=config.STUMPY_FASTMATH_FLAGS ) def _apply_exclusion_zone(a, idx, excl_zone, val): """ diff --git a/stumpy/maamp.py b/stumpy/maamp.py index dad6748c3..a216f9fc4 100644 --- a/stumpy/maamp.py +++ b/stumpy/maamp.py @@ -592,7 +592,7 @@ def _get_multi_p_norm(start, T, m, p=2.0): # "(i8, i8, i8, f8[:, :], f8[:, :], i8, i8, b1[:, :], b1[:, :], f8," # "f8[:, :], f8[:, :], f8[:, :])", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _compute_multi_p_norm( d, diff --git a/stumpy/mstump.py b/stumpy/mstump.py index c4b7ed2c9..35d58b130 100644 --- a/stumpy/mstump.py +++ b/stumpy/mstump.py @@ -811,7 +811,7 @@ def _get_multi_QT(start, T, m): # "(i8, i8, i8, f8[:, :], f8[:, :], i8, i8, f8[:, :], f8[:, :], f8[:, :]," # "f8[:, :], f8[:, :], f8[:, :], f8[:, :])", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _compute_multi_D( d, diff --git a/stumpy/scraamp.py b/stumpy/scraamp.py index 56d83f6b6..682a83405 100644 --- a/stumpy/scraamp.py +++ b/stumpy/scraamp.py @@ -83,7 +83,7 @@ def _preprocess_prescraamp(T_A, m, T_B=None, s=None): return (T_A, T_B, T_A_subseq_isfinite, T_B_subseq_isfinite, indices, s, excl_zone) -@njit(fastmath=config.STUMPY_FASTMATH_TRUE) +@njit(fastmath=config.STUMPY_FASTMATH_FLAGS) def _compute_PI( T_A, T_B, @@ -286,7 +286,7 @@ def _compute_PI( # "(f8[:], f8[:], i8, b1[:], b1[:], f8, i8, i8, f8[:], f8[:]," # "i8[:], optional(i8))", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _prescraamp( T_A, diff --git a/stumpy/scrump.py b/stumpy/scrump.py index dd5617480..b9894770d 100644 --- a/stumpy/scrump.py +++ b/stumpy/scrump.py @@ -133,7 +133,7 @@ def _preprocess_prescrump( ) 
-@njit(fastmath=config.STUMPY_FASTMATH_TRUE) +@njit(fastmath=config.STUMPY_FASTMATH_FLAGS) def _compute_PI( T_A, T_B, @@ -384,7 +384,7 @@ def _compute_PI( # "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], i8, i8, f8[:], f8[:]," # "i8[:], optional(i8))", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _prescrump( T_A, diff --git a/stumpy/stump.py b/stumpy/stump.py index 18409c6e1..10e4d0e3f 100644 --- a/stumpy/stump.py +++ b/stumpy/stump.py @@ -15,7 +15,7 @@ # "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], f8[:]," # "b1[:], b1[:], b1[:], b1[:], i8[:], i8, i8, i8, f8[:, :, :], f8[:, :]," # "f8[:, :], i8[:, :, :], i8[:, :], i8[:, :], b1)", - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _compute_diagonal( T_A, @@ -247,7 +247,7 @@ def _compute_diagonal( # "(f8[:], f8[:], i8, f8[:], f8[:], f8[:], f8[:], f8[:], f8[:], b1[:], b1[:]," # "b1[:], b1[:], i8[:], b1, i8)", parallel=True, - fastmath=config.STUMPY_FASTMATH_TRUE, + fastmath=config.STUMPY_FASTMATH_FLAGS, ) def _stump( T_A,