
Commit 08bdec6

Chilleezou3519 authored and committed
[functorch] updated op analysis script
1 parent eb292d7 commit 08bdec6

File tree

2 files changed: +647 −69 lines changed


functorch/op_analysis/gen_data.py

Lines changed: 24 additions & 69 deletions
@@ -1,75 +1,15 @@
 import yaml
 import csv
 import torch
-import sys
-import os
 from collections import defaultdict
 
 
-class CapturedOutput(object):
-    """
-    Class used to grab standard output.
-    We need this instead of contextlib.redirect_stdout() if the printed text
-    that we want to capture comes from C++.
-    The result is stored in capturedtext.
-    Pulled partially from https://www.py4u.net/discuss/66399.
-    """
-    escape_char = "\b"
-
-    def __init__(self):
-        self.origstream = sys.stdout
-        self.origstreamfd = self.origstream.fileno()
-        self.capturedtext = ""
-        # Create a pipe so the stream can be captured:
-        self.pipe_out, self.pipe_in = os.pipe()
-
-    def __enter__(self):
-        self.capturedtext = ""
-        # Save a copy of the stream:
-        self.streamfd = os.dup(self.origstreamfd)
-        # Replace the original stream with our write pipe:
-        os.dup2(self.pipe_in, self.origstreamfd)
-        return self
-
-    def __exit__(self, type, value, traceback):
-        # Print the escape character to make the readOutput method stop:
-        self.origstream.write(self.escape_char)
-        # Flush the stream to make sure all our data goes in before
-        # the escape character:
-        self.origstream.flush()
-        self.readOutput()
-        # Close the pipe:
-        os.close(self.pipe_in)
-        os.close(self.pipe_out)
-        # Restore the original stream:
-        os.dup2(self.streamfd, self.origstreamfd)
-        # Close the duplicate stream:
-        os.close(self.streamfd)
-
-    def readOutput(self):
-        """
-        Read the stream data (one byte at a time)
-        and save the text in `capturedtext`.
-        """
-        while True:
-            char = os.read(self.pipe_out, 1)
-            if not char:
-                break
-            char = char.decode("utf-8")
-            if self.escape_char in char:
-                break
-            self.capturedtext += char
-
-
 def get_ops_for_key(key):
-    all_out = CapturedOutput()
-    with all_out:
-        if key is None:
-            torch._C._dispatch_print_registrations_for_dispatch_key()
-        else:
-            torch._C._dispatch_print_registrations_for_dispatch_key(key)
-
-    ops = all_out.capturedtext.split('\n')
+    # Needs modified PyTorch C++ code to work
+    if key is None:
+        ops = torch._C._dispatch_get_registrations_for_dispatch_key()
+    else:
+        ops = torch._C._dispatch_get_registrations_for_dispatch_key(key)
     cleaned_ops = []
     for i in ops:
         if 'aten::' not in i:
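
Note on this hunk: get_ops_for_key now asks the dispatcher for its registrations directly via torch._C._dispatch_get_registrations_for_dispatch_key, which returns the operator names as Python strings, so the CapturedOutput stdout-scraping helper and its sys/os imports can be deleted. The sketch below shows how the new binding is typically used; per the in-diff comment it needed a modified PyTorch build at the time, and the 'CompositeImplicitAutograd' key and the set cleanup are illustrative assumptions, not lines from this file.

import torch

# Registrations for one dispatch key; entries look like "aten::add.Tensor".
# CompositeImplicitAutograd kernels are the ops that decompose into other ops.
composite = torch._C._dispatch_get_registrations_for_dispatch_key('CompositeImplicitAutograd')

# Called with no key, the binding returns every registered operator.
all_ops = torch._C._dispatch_get_registrations_for_dispatch_key()

# Keep only ATen ops and drop the namespace, roughly mirroring the
# cleanup loop that follows in gen_data.py.
aten_composite = {op.replace('aten::', '') for op in composite if 'aten::' in op}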
@@ -161,7 +101,6 @@ def annotate_ops(ops, is_unique):
 
 annotate_ops(ops, is_unique=False)
 with open(f"{analysis_name}", 'w') as f:
-    # import pdb; pdb.set_trace()
     for op in ops:
         info = [
             op['full_name'], op['meta'], not (op['full_name'] in noncomposite_ops)
@@ -186,6 +125,11 @@ def remove_suffix(input_string, suffix):
         return input_string[:-len(suffix)]
     return input_string
 
+def remove_prefix(input_string, prefix):
+    if prefix and input_string.startswith(prefix):
+        return input_string[len(prefix):]
+    return input_string
+
 
 if True:
     with open('run_ops.txt', 'r') as f:
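
The added remove_prefix mirrors the existing remove_suffix helper: strip a leading rather than trailing substring, only when it is present. A quick illustrative check, with made-up operator names:

assert remove_prefix("linalg_svd", "linalg_") == "svd"
assert remove_prefix("svd", "linalg_") == "svd"  # unchanged when the prefix is absent
assert remove_suffix("add.Tensor.default", ".default") == "add.Tensor"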
@@ -194,9 +138,20 @@ def remove_suffix(input_string, suffix):
         opinfo_counts = [i.strip() for i in f.readlines()]
     opinfo_counts = defaultdict(int, {k: v for k, v in zip(opinfo_ops, opinfo_counts)})
 
-    def count_fn(x):
-        return opinfo_counts[x['full_name']]
+    def count_fn(x):
+        return opinfo_counts[x['full_name']]
 
     with open('run_decompositions.txt', 'r') as f:
         decomposed_ops = [remove_suffix(i.strip(), '.default') for i in f.readlines()]
-    gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn], 'decompositions.txt')
+
+    with open('public_api', 'r') as f:
+        ref_api = [i.strip() for i in f.readlines()]
+
+    def has_ref_impl(x):
+        name = x['name']
+        for prefix in ["linalg_", "special_"]:
+            name = remove_prefix(name, prefix)
+        prefixes = ['nn.functional', 'fft', 'special', 'linalg']
+        return any(f"{prefix}.{name}" in ref_api for prefix in prefixes) or name in ref_api
+
+    gen_data([full_name_check(opinfo_ops), full_name_check(decomposed_ops), count_fn, has_ref_impl], 'decompositions.txt')
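
The last hunk wires a fourth marker, has_ref_impl, into gen_data: it reports whether an op's name (after stripping a linalg_ or special_ prefix) appears in the public_api file, either bare or under one of the nn.functional/fft/special/linalg namespaces. A small sketch of how it classifies ops, assuming public_api contained entries such as 'linalg.svd' and 'nn.functional.relu' (illustrative contents, not the real file):

ref_api = ['linalg.svd', 'nn.functional.relu', 'add']

# 'linalg_svd' -> 'svd' after prefix stripping, and 'linalg.svd' is in ref_api.
assert has_ref_impl({'name': 'linalg_svd'})
# 'conv2d' matches neither a namespaced entry nor a bare name.
assert not has_ref_impl({'name': 'conv2d'})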

0 commit comments
