1
1
import yaml
2
2
import csv
3
3
import torch
4
- import sys
5
- import os
6
4
from collections import defaultdict
7
5
8
6
9
- class CapturedOutput (object ):
10
- """
11
- Class used to grab standard output.
12
- We need this instead of contextlib.redirect_stdout() if the printed text
13
- that we want to capture comes from C++.
14
- The result is stored in capturedtext.
15
- Pulled partially from https://www.py4u.net/discuss/66399.
16
- """
17
- escape_char = "\b "
18
-
19
- def __init__ (self ):
20
- self .origstream = sys .stdout
21
- self .origstreamfd = self .origstream .fileno ()
22
- self .capturedtext = ""
23
- # Create a pipe so the stream can be captured:
24
- self .pipe_out , self .pipe_in = os .pipe ()
25
-
26
- def __enter__ (self ):
27
- self .capturedtext = ""
28
- # Save a copy of the stream:
29
- self .streamfd = os .dup (self .origstreamfd )
30
- # Replace the original stream with our write pipe:
31
- os .dup2 (self .pipe_in , self .origstreamfd )
32
- return self
33
-
34
- def __exit__ (self , type , value , traceback ):
35
- # Print the escape character to make the readOutput method stop:
36
- self .origstream .write (self .escape_char )
37
- # Flush the stream to make sure all our data goes in before
38
- # the escape character:
39
- self .origstream .flush ()
40
- self .readOutput ()
41
- # Close the pipe:
42
- os .close (self .pipe_in )
43
- os .close (self .pipe_out )
44
- # Restore the original stream:
45
- os .dup2 (self .streamfd , self .origstreamfd )
46
- # Close the duplicate stream:
47
- os .close (self .streamfd )
48
-
49
- def readOutput (self ):
50
- """
51
- Read the stream data (one byte at a time)
52
- and save the text in `capturedtext`.
53
- """
54
- while True :
55
- char = os .read (self .pipe_out , 1 )
56
- if not char :
57
- break
58
- char = char .decode ("utf-8" )
59
- if self .escape_char in char :
60
- break
61
- self .capturedtext += char
62
-
63
-
64
7
def get_ops_for_key (key ):
65
- all_out = CapturedOutput ()
66
- with all_out :
67
- if key is None :
68
- torch ._C ._dispatch_print_registrations_for_dispatch_key ()
69
- else :
70
- torch ._C ._dispatch_print_registrations_for_dispatch_key (key )
71
-
72
- ops = all_out .capturedtext .split ('\n ' )
8
+ # Needs modified PyTorch C++ code to work
9
+ if key is None :
10
+ ops = torch ._C ._dispatch_get_registrations_for_dispatch_key ()
11
+ else :
12
+ ops = torch ._C ._dispatch_get_registrations_for_dispatch_key (key )
73
13
cleaned_ops = []
74
14
for i in ops :
75
15
if 'aten::' not in i :
@@ -161,7 +101,6 @@ def annotate_ops(ops, is_unique):
161
101
162
102
annotate_ops (ops , is_unique = False )
163
103
with open (f"{ analysis_name } " , 'w' ) as f :
164
- # import pdb; pdb.set_trace()
165
104
for op in ops :
166
105
info = [
167
106
op ['full_name' ], op ['meta' ], not (op ['full_name' ] in noncomposite_ops )
@@ -186,6 +125,11 @@ def remove_suffix(input_string, suffix):
186
125
return input_string [:- len (suffix )]
187
126
return input_string
188
127
128
+ def remove_prefix (input_string , prefix ):
129
+ if prefix and input_string .startswith (prefix ):
130
+ return input_string [len (prefix ):]
131
+ return input_string
132
+
189
133
190
134
if True :
191
135
with open ('run_ops.txt' , 'r' ) as f :
@@ -194,9 +138,20 @@ def remove_suffix(input_string, suffix):
194
138
opinfo_counts = [i .strip () for i in f .readlines ()]
195
139
opinfo_counts = defaultdict (int , {k : v for k , v in zip (opinfo_ops , opinfo_counts )})
196
140
197
- def count_fn (x ):
198
- return opinfo_counts [x ['full_name' ]]
141
+ def count_fn (x ):
142
+ return opinfo_counts [x ['full_name' ]]
199
143
200
144
with open ('run_decompositions.txt' , 'r' ) as f :
201
145
decomposed_ops = [remove_suffix (i .strip (), '.default' ) for i in f .readlines ()]
202
- gen_data ([full_name_check (opinfo_ops ), full_name_check (decomposed_ops ), count_fn ], 'decompositions.txt' )
146
+
147
+ with open ('public_api' , 'r' ) as f :
148
+ ref_api = [i .strip () for i in f .readlines ()]
149
+
150
+ def has_ref_impl (x ):
151
+ name = x ['name' ]
152
+ for prefix in ["linalg_" , "special_" ]:
153
+ name = remove_prefix (name , prefix )
154
+ prefixes = ['nn.functional' , 'fft' , 'special' , 'linalg' ]
155
+ return any (f"{ prefix } .{ name } " in ref_api for prefix in prefixes ) or name in ref_api
156
+
157
+ gen_data ([full_name_check (opinfo_ops ), full_name_check (decomposed_ops ), count_fn , has_ref_impl ], 'decompositions.txt' )
0 commit comments