Skip to content

Commit 74e9605

Browse files
Caner Gocmen authored and facebook-github-bot committed
stop OSS planner when time limits are hit
Summary: Adds a `timeout_seconds` argument to the OSS planner so that, once the time budget is exceeded, the planner stops searching and returns the best plan found so far. Reviewed By: iamzainhuda. Differential Revision: D70015750. fbshipit-source-id: 428df430afba3335acb87cf8a8dabcaa5284fc17
1 parent 6aaf1fa commit 74e9605

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

torchrec/distributed/planner/planners.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,8 @@
88
# pyre-strict
99

1010
import copy
11+
import logging
12+
import time
1113
from functools import reduce
1214
from time import perf_counter
1315
from typing import Callable, cast, Dict, List, Optional, Tuple, Union
@@ -66,6 +68,8 @@
6668
)
6769
from torchrec.distributed.utils import none_throws
6870

71+
logger: logging.Logger = logging.getLogger(__name__)
72+
6973

7074
def to_sharding_plan(
7175
sharding_options: List[ShardingOption],
@@ -181,6 +185,7 @@ def __init__(
181185
callbacks: Optional[
182186
List[Callable[[List[ShardingOption]], List[ShardingOption]]]
183187
] = None,
188+
timeout_seconds: Optional[int] = None,
184189
) -> None:
185190
if topology is None:
186191
topology = Topology(
@@ -235,6 +240,9 @@ def __init__(
235240
self._callbacks: List[
236241
Callable[[List[ShardingOption]], List[ShardingOption]]
237242
] = ([] if callbacks is None else callbacks)
243+
if timeout_seconds is not None:
244+
assert timeout_seconds > 0, "Timeout must be positive"
245+
self._timeout_seconds = timeout_seconds
238246

239247
def collective_plan(
240248
self,
@@ -320,10 +328,19 @@ def plan(
320328
for proposer in self._proposers:
321329
proposer.load(search_space=search_space, enumerator=self._enumerator)
322330

331+
start = time.time()
323332
for proposer in self._proposers:
324333
proposal = proposer.propose()
325334

326335
while proposal:
336+
end = time.time()
337+
elapsed = end - start
338+
if self._timeout_seconds:
339+
if elapsed > self._timeout_seconds:
340+
logger.info(
341+
f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s"
342+
)
343+
break
327344
proposal_key = tuple(sorted(map(hash, proposal)))
328345
if proposal_key in proposal_cache:
329346
partitionable, plan, perf_rating = proposal_cache[proposal_key]

0 commit comments

Comments
 (0)