@@ -183,7 +183,7 @@ def safe_append(dct, key, val):
183
183
}
184
184
185
185
186
- def get_top_ops (torch_threshold , nn_fn_threshold ):
186
+ def get_top_ops (torch_threshold , nn_fn_threshold , with_counts = False ):
187
187
denylist = set ({
188
188
# These are either not real "operators", factory functions
189
189
# that trivially work, or not-documented ops.
@@ -228,12 +228,17 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
228
228
'fft' , # is namespace
229
229
})
230
230
231
- torch_ops = [op [ 0 ] for op in top_ops .top_torch ]
232
- nn_fn_ops = [op [ 0 ] for op in top_ops .get_nn_functional_top_list ()]
233
- torch_ops = [op for op in torch_ops if op not in denylist ]
234
- nn_fn_ops = [op for op in nn_fn_ops if op not in denylist ]
231
+ torch_ops = [op for op in top_ops .top_torch ]
232
+ nn_fn_ops = [op for op in top_ops .get_nn_functional_top_list ()]
233
+ torch_ops = [op for op in torch_ops if op [ 0 ] not in denylist ]
234
+ nn_fn_ops = [op for op in nn_fn_ops if op [ 0 ] not in denylist ]
235
235
236
236
ops = torch_ops [:torch_threshold ] + nn_fn_ops [:nn_fn_threshold ]
237
+
238
+ # Now, sort by priority
239
+ ops .sort (reverse = True , key = lambda op : op [1 ])
240
+ if not with_counts :
241
+ ops = [op [0 ] for op in ops ]
237
242
return ops
238
243
239
244
@@ -341,8 +346,6 @@ def get_skipped_or_xfailed_ops_for(test_name):
341
346
return result
342
347
343
348
344
- # import pdb; pdb.set_trace()
345
-
346
349
def get_statuses (for_subset = None , invert = False ):
347
350
overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about ()
348
351
if for_subset is not None :
@@ -886,3 +889,8 @@ def summary(self):
886
889
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
887
890
# pprint.pprint(result)
888
891
# print(opset.summary())
892
+
893
+ # Print list of everything in order
894
+ # all_ops = get_top_ops(999999, 999999, with_counts=True)
895
+ # for op, count in all_ops:
896
+ # print(f'{op}, {count}')
0 commit comments