Skip to content

Commit 2e56081

Browse files
committed
[functorch] update discover_coverage
1 parent a1d8959 commit 2e56081

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

functorch/test/discover_coverage.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def safe_append(dct, key, val):
183183
}
184184

185185

186-
def get_top_ops(torch_threshold, nn_fn_threshold):
186+
def get_top_ops(torch_threshold, nn_fn_threshold, with_counts=False):
187187
denylist = set({
188188
# These are either not real "operators", factory functions
189189
# that trivially work, or not-documented ops.
@@ -228,12 +228,17 @@ def get_top_ops(torch_threshold, nn_fn_threshold):
228228
'fft', # is namespace
229229
})
230230

231-
torch_ops = [op[0] for op in top_ops.top_torch]
232-
nn_fn_ops = [op[0] for op in top_ops.get_nn_functional_top_list()]
233-
torch_ops = [op for op in torch_ops if op not in denylist]
234-
nn_fn_ops = [op for op in nn_fn_ops if op not in denylist]
231+
torch_ops = [op for op in top_ops.top_torch]
232+
nn_fn_ops = [op for op in top_ops.get_nn_functional_top_list()]
233+
torch_ops = [op for op in torch_ops if op[0] not in denylist]
234+
nn_fn_ops = [op for op in nn_fn_ops if op[0] not in denylist]
235235

236236
ops = torch_ops[:torch_threshold] + nn_fn_ops[:nn_fn_threshold]
237+
238+
# Now, sort by priority
239+
ops.sort(reverse=True, key=lambda op: op[1])
240+
if not with_counts:
241+
ops = [op[0] for op in ops]
237242
return ops
238243

239244

@@ -341,8 +346,6 @@ def get_skipped_or_xfailed_ops_for(test_name):
341346
return result
342347

343348

344-
# import pdb; pdb.set_trace()
345-
346349
def get_statuses(for_subset=None, invert=False):
347350
overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
348351
if for_subset is not None:
@@ -886,3 +889,8 @@ def summary(self):
886889
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
887890
# pprint.pprint(result)
888891
# print(opset.summary())
892+
893+
# Print list of everything in order
894+
# all_ops = get_top_ops(999999, 999999, with_counts=True)
895+
# for op, count in all_ops:
896+
# print(f'{op}, {count}')

0 commit comments

Comments
 (0)