Skip to content

Commit fa705f7

Browse files
nmacchionipytorchmergebot
authored andcommitted
[BE] minor refactor + some comments on behavior (pytorch#154695)
Pull Request resolved: pytorch#154695 Approved by: https://github.com/masnesral, https://github.com/eellison
1 parent 9e88d6c commit fa705f7

File tree

1 file changed

+27
-7
lines changed

1 file changed

+27
-7
lines changed

torch/_inductor/select_algorithm.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# mypy: allow-untyped-defs
2-
import builtins
32
import contextlib
43
import dataclasses
54
import functools
@@ -2334,13 +2333,34 @@ def get_timings():
23342333
)
23352334

23362335
timings = do_autotuning(choices, precompile_fn)
2337-
if timings == {} or choices[0] not in timings:
2338-
return choices[0].output_node()
23392336

2340-
selected_key = builtins.min(timings, key=timings.__getitem__)
2341-
selected_choice = selected_key.output_node()
2342-
log.debug("selected choice: %s", str(selected_choice))
2343-
return selected_choice
2337+
# if timings is empty, we really have no choice but to return a semi-random
2338+
# choice. returning the first `ExternKernelCaller` is probably the safest bet
2339+
# in this case, since it will generally be the ATen kernel. if there are no
2340+
# `ExternKernelCaller`s to return, then returning the 0th kernel is our next
2341+
# best option (ideally we'd fail whenever there is no ATen kernel to fallback
2342+
# to, but that's not trivial to figure out)
2343+
if timings == {}:
2344+
for choice in choices:
2345+
if isinstance(choice, ExternKernelCaller):
2346+
node = choice.output_node()
2347+
log.debug(
2348+
"Autotuning returned empty timings, falling back to first `ExternKernelCaller`: %s",
2349+
node,
2350+
)
2351+
return node
2352+
node = choices[0].output_node()
2353+
log.debug(
2354+
"Autotuning returned empty timings, falling back to first choice: %s",
2355+
node,
2356+
)
2357+
return node
2358+
2359+
# if we got any timings at all, pick the best of those
2360+
choice = min(timings, key=timings.__getitem__)
2361+
node = choice.output_node()
2362+
log.debug("Autotuning selected choice: %s", node)
2363+
return node
23442364

23452365
def make_precompile_fn(
23462366
self,

0 commit comments

Comments
 (0)