Skip to content

Commit ad2c4f2

Browse files
committed
temp commit
1 parent c24f0a9 commit ad2c4f2

File tree

1 file changed

+225
-0
lines changed

1 file changed

+225
-0
lines changed

stumpy/utils.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
import ast
2+
3+
import pathlib
4+
5+
from stumpy import cache
6+
7+
def check_fastmath(decorator):
8+
"""
9+
For the given `decorator` node with type `ast.Call`,
10+
return the value of the `fastmath` argument if it exists.
11+
Otherwise, return `None`.
12+
"""
13+
fastmath_value = None
14+
for n in ast.iter_child_nodes(decorator):
15+
if isinstance(n, ast.keyword) and n.arg == "fastmath":
16+
if isinstance(n.value, ast.Constant):
17+
fastmath_value = n.value.value
18+
elif isinstance(n.value, ast.Set):
19+
fastmath_value = set(item.value for item in n.value.elts)
20+
else:
21+
pass
22+
break
23+
24+
return fastmath_value
25+
26+
27+
def check_njit(fd):
28+
"""
29+
For the given `fd` node with type `ast.FunctionDef`,
30+
return the node of the `njit` decorator if it exists.
31+
Otherwise, return `None`.
32+
"""
33+
decorator_node = None
34+
for decorator in fd.decorator_list:
35+
if not isinstance(decorator, ast.Call):
36+
continue
37+
38+
obj = decorator.func
39+
if isinstance(obj, ast.Attribute):
40+
name = obj.attr
41+
elif isinstance(obj, ast.Subscript):
42+
name = obj.value.id
43+
elif isinstance(obj, ast.Name):
44+
name = obj.id
45+
else:
46+
msg = f"The type {type(obj)} is not supported."
47+
raise ValueError(msg)
48+
49+
if name == "njit":
50+
decorator_node = decorator
51+
break
52+
53+
return decorator_node
54+
55+
56+
def check_functions(filepath):
57+
"""
58+
For the given `filepath`, return the function names,
59+
whether the function is decorated with `@njit`,
60+
and the value of the `fastmath` argument if it exists
61+
62+
Parameters
63+
----------
64+
filepath : str
65+
The path to the file
66+
67+
Returns
68+
-------
69+
func_names : list
70+
List of function names
71+
72+
is_njit : list
73+
List of boolean values indicating whether the function is decorated with `@njit`
74+
75+
fastmath_value : list
76+
List of values of the `fastmath` argument if it exists
77+
"""
78+
file_contents = ""
79+
with open(filepath, encoding="utf8") as f:
80+
file_contents = f.read()
81+
module = ast.parse(file_contents)
82+
83+
function_definitions = [
84+
node for node in module.body if isinstance(node, ast.FunctionDef)
85+
]
86+
87+
func_names = [fd.name for fd in function_definitions]
88+
89+
njit_nodes = [check_njit(fd) for fd in function_definitions]
90+
is_njit = [node is not None for node in njit_nodes]
91+
92+
fastmath_values = [None] * len(njit_nodes)
93+
for i, node in enumerate(njit_nodes):
94+
if node is not None:
95+
fastmath_values[i] = check_fastmath(node)
96+
97+
return func_names, is_njit, fastmath_values
98+
99+
100+
def _get_callees(node, all_functions):
101+
for n in ast.iter_child_nodes(node):
102+
if isinstance(n, ast.Call):
103+
obj = n.func
104+
if isinstance(obj, ast.Attribute):
105+
name = obj.attr
106+
elif isinstance(obj, ast.Subscript):
107+
name = obj.value.id
108+
elif isinstance(obj, ast.Name):
109+
name = obj.id
110+
else:
111+
msg = f"The type {type(obj)} is not supported"
112+
raise ValueError(msg)
113+
114+
all_functions.append(name)
115+
116+
_get_callees(n, all_functions)
117+
118+
119+
def get_all_callees(fd):
120+
"""
121+
For a given node of type ast.FunctionDef, visit all of its child nodes,
122+
and return a list of all of its callees
123+
"""
124+
all_functions = []
125+
_get_callees(fd, all_functions)
126+
127+
return all_functions
128+
129+
130+
def check_callees(filepath):
131+
"""
132+
For the given `filepath`, return a dictionary with the key
133+
being the function name and the value being a set of function names
134+
that are called by the function
135+
"""
136+
file_contents = ""
137+
with open(filepath, encoding="utf8") as f:
138+
file_contents = f.read()
139+
module = ast.parse(file_contents)
140+
141+
function_definitions = [
142+
node for node in module.body if isinstance(node, ast.FunctionDef)
143+
]
144+
145+
callees = {}
146+
for fd in function_definitions:
147+
callees[fd.name] = set(get_all_callees(fd))
148+
149+
return callees
150+
151+
152+
stumpy_path = pathlib.Path(__file__).parent # / "stumpy"
153+
filepaths = sorted(f for f in pathlib.Path(stumpy_path).iterdir() if f.is_file())
154+
155+
all_functions = {}
156+
157+
ignore = ["__init__.py", "__pycache__"]
158+
for filepath in filepaths:
159+
file_name = filepath.name
160+
if file_name not in ignore and str(filepath).endswith(".py"):
161+
prefix = file_name.replace(".py", "")
162+
163+
func_names, is_njit, fastmath_values = check_functions(filepath)
164+
func_names = [f"{prefix}.{fn}" for fn in func_names]
165+
166+
all_functions[file_name] = {
167+
"func_names": func_names,
168+
"is_njit": is_njit,
169+
"fastmath_values": fastmath_values,
170+
}
171+
172+
all_stumpy_functions = set()
173+
for file_name, file_functions_metadata in all_functions.items():
174+
all_stumpy_functions.update(file_functions_metadata["func_names"])
175+
176+
all_stumpy_functions = list(all_stumpy_functions)
177+
all_stumpy_functions_no_prefix = [f.split(".")[-1] for f in all_stumpy_functions]
178+
179+
180+
# output 1: func_metadata
181+
func_metadata = {}
182+
for file_name, file_functions_metadata in all_functions.items():
183+
for i, f in enumerate(file_functions_metadata["func_names"]):
184+
is_njit = file_functions_metadata["is_njit"][i]
185+
fastmath_value = file_functions_metadata["fastmath_values"][i]
186+
func_metadata[f] = [is_njit, fastmath_value]
187+
188+
189+
# output 2: func_callers
190+
func_callers = {}
191+
for f in func_metadata.keys():
192+
func_callers[f] = []
193+
194+
for filepath in filepaths:
195+
file_name = filepath.name
196+
if file_name in ignore or not str(filepath).endswith(".py"):
197+
continue
198+
199+
prefix = file_name.replace(".py", "")
200+
callees = check_callees(filepath)
201+
202+
current_callers = set(callees.keys())
203+
for caller, callee_set in callees.items():
204+
s = list(callee_set.intersection(all_stumpy_functions_no_prefix))
205+
if len(s) == 0:
206+
continue
207+
208+
for c in s:
209+
if c in current_callers:
210+
c_name = prefix + "." + c
211+
else:
212+
idx = all_stumpy_functions_no_prefix.index(c)
213+
c_name = all_stumpy_functions[idx]
214+
215+
func_callers[c_name].append(f"{prefix}.{caller}")
216+
217+
218+
for f, callers in func_callers.items():
219+
func_callers[f] = list(set(callers))
220+
221+
222+
223+
for modue_name, func_name in cache.get_njit_funcs():
224+
f = f"{modue_name}.{func_name}"
225+
print(f, func_callers[f])

0 commit comments

Comments
 (0)