1+ import ast
2+
3+ import pathlib
4+
5+ from stumpy import cache
6+
7+ def check_fastmath (decorator ):
8+ """
9+ For the given `decorator` node with type `ast.Call`,
10+ return the value of the `fastmath` argument if it exists.
11+ Otherwise, return `None`.
12+ """
13+ fastmath_value = None
14+ for n in ast .iter_child_nodes (decorator ):
15+ if isinstance (n , ast .keyword ) and n .arg == "fastmath" :
16+ if isinstance (n .value , ast .Constant ):
17+ fastmath_value = n .value .value
18+ elif isinstance (n .value , ast .Set ):
19+ fastmath_value = set (item .value for item in n .value .elts )
20+ else :
21+ pass
22+ break
23+
24+ return fastmath_value
25+
26+
27+ def check_njit (fd ):
28+ """
29+ For the given `fd` node with type `ast.FunctionDef`,
30+ return the node of the `njit` decorator if it exists.
31+ Otherwise, return `None`.
32+ """
33+ decorator_node = None
34+ for decorator in fd .decorator_list :
35+ if not isinstance (decorator , ast .Call ):
36+ continue
37+
38+ obj = decorator .func
39+ if isinstance (obj , ast .Attribute ):
40+ name = obj .attr
41+ elif isinstance (obj , ast .Subscript ):
42+ name = obj .value .id
43+ elif isinstance (obj , ast .Name ):
44+ name = obj .id
45+ else :
46+ msg = f"The type { type (obj )} is not supported."
47+ raise ValueError (msg )
48+
49+ if name == "njit" :
50+ decorator_node = decorator
51+ break
52+
53+ return decorator_node
54+
55+
56+ def check_functions (filepath ):
57+ """
58+ For the given `filepath`, return the function names,
59+ whether the function is decorated with `@njit`,
60+ and the value of the `fastmath` argument if it exists
61+
62+ Parameters
63+ ----------
64+ filepath : str
65+ The path to the file
66+
67+ Returns
68+ -------
69+ func_names : list
70+ List of function names
71+
72+ is_njit : list
73+ List of boolean values indicating whether the function is decorated with `@njit`
74+
75+ fastmath_value : list
76+ List of values of the `fastmath` argument if it exists
77+ """
78+ file_contents = ""
79+ with open (filepath , encoding = "utf8" ) as f :
80+ file_contents = f .read ()
81+ module = ast .parse (file_contents )
82+
83+ function_definitions = [
84+ node for node in module .body if isinstance (node , ast .FunctionDef )
85+ ]
86+
87+ func_names = [fd .name for fd in function_definitions ]
88+
89+ njit_nodes = [check_njit (fd ) for fd in function_definitions ]
90+ is_njit = [node is not None for node in njit_nodes ]
91+
92+ fastmath_values = [None ] * len (njit_nodes )
93+ for i , node in enumerate (njit_nodes ):
94+ if node is not None :
95+ fastmath_values [i ] = check_fastmath (node )
96+
97+ return func_names , is_njit , fastmath_values
98+
99+
100+ def _get_callees (node , all_functions ):
101+ for n in ast .iter_child_nodes (node ):
102+ if isinstance (n , ast .Call ):
103+ obj = n .func
104+ if isinstance (obj , ast .Attribute ):
105+ name = obj .attr
106+ elif isinstance (obj , ast .Subscript ):
107+ name = obj .value .id
108+ elif isinstance (obj , ast .Name ):
109+ name = obj .id
110+ else :
111+ msg = f"The type { type (obj )} is not supported"
112+ raise ValueError (msg )
113+
114+ all_functions .append (name )
115+
116+ _get_callees (n , all_functions )
117+
118+
119+ def get_all_callees (fd ):
120+ """
121+ For a given node of type ast.FunctionDef, visit all of its child nodes,
122+ and return a list of all of its callees
123+ """
124+ all_functions = []
125+ _get_callees (fd , all_functions )
126+
127+ return all_functions
128+
129+
130+ def check_callees (filepath ):
131+ """
132+ For the given `filepath`, return a dictionary with the key
133+ being the function name and the value being a set of function names
134+ that are called by the function
135+ """
136+ file_contents = ""
137+ with open (filepath , encoding = "utf8" ) as f :
138+ file_contents = f .read ()
139+ module = ast .parse (file_contents )
140+
141+ function_definitions = [
142+ node for node in module .body if isinstance (node , ast .FunctionDef )
143+ ]
144+
145+ callees = {}
146+ for fd in function_definitions :
147+ callees [fd .name ] = set (get_all_callees (fd ))
148+
149+ return callees
150+
151+
152+ stumpy_path = pathlib .Path (__file__ ).parent # / "stumpy"
153+ filepaths = sorted (f for f in pathlib .Path (stumpy_path ).iterdir () if f .is_file ())
154+
155+ all_functions = {}
156+
157+ ignore = ["__init__.py" , "__pycache__" ]
158+ for filepath in filepaths :
159+ file_name = filepath .name
160+ if file_name not in ignore and str (filepath ).endswith (".py" ):
161+ prefix = file_name .replace (".py" , "" )
162+
163+ func_names , is_njit , fastmath_values = check_functions (filepath )
164+ func_names = [f"{ prefix } .{ fn } " for fn in func_names ]
165+
166+ all_functions [file_name ] = {
167+ "func_names" : func_names ,
168+ "is_njit" : is_njit ,
169+ "fastmath_values" : fastmath_values ,
170+ }
171+
172+ all_stumpy_functions = set ()
173+ for file_name , file_functions_metadata in all_functions .items ():
174+ all_stumpy_functions .update (file_functions_metadata ["func_names" ])
175+
176+ all_stumpy_functions = list (all_stumpy_functions )
177+ all_stumpy_functions_no_prefix = [f .split ("." )[- 1 ] for f in all_stumpy_functions ]
178+
179+
180+ # output 1: func_metadata
181+ func_metadata = {}
182+ for file_name , file_functions_metadata in all_functions .items ():
183+ for i , f in enumerate (file_functions_metadata ["func_names" ]):
184+ is_njit = file_functions_metadata ["is_njit" ][i ]
185+ fastmath_value = file_functions_metadata ["fastmath_values" ][i ]
186+ func_metadata [f ] = [is_njit , fastmath_value ]
187+
188+
189+ # output 2: func_callers
190+ func_callers = {}
191+ for f in func_metadata .keys ():
192+ func_callers [f ] = []
193+
194+ for filepath in filepaths :
195+ file_name = filepath .name
196+ if file_name in ignore or not str (filepath ).endswith (".py" ):
197+ continue
198+
199+ prefix = file_name .replace (".py" , "" )
200+ callees = check_callees (filepath )
201+
202+ current_callers = set (callees .keys ())
203+ for caller , callee_set in callees .items ():
204+ s = list (callee_set .intersection (all_stumpy_functions_no_prefix ))
205+ if len (s ) == 0 :
206+ continue
207+
208+ for c in s :
209+ if c in current_callers :
210+ c_name = prefix + "." + c
211+ else :
212+ idx = all_stumpy_functions_no_prefix .index (c )
213+ c_name = all_stumpy_functions [idx ]
214+
215+ func_callers [c_name ].append (f"{ prefix } .{ caller } " )
216+
217+
218+ for f , callers in func_callers .items ():
219+ func_callers [f ] = list (set (callers ))
220+
221+
222+
223+ for modue_name , func_name in cache .get_njit_funcs ():
224+ f = f"{ modue_name } .{ func_name } "
225+ print (f , func_callers [f ])
0 commit comments