|
8 | 8 | # pyre-strict
|
9 | 9 |
|
10 | 10 | import copy
|
| 11 | +import logging |
| 12 | +import time |
11 | 13 | from functools import reduce
|
12 | 14 | from time import perf_counter
|
13 | 15 | from typing import Callable, cast, Dict, List, Optional, Tuple, Union
|
|
66 | 68 | )
|
67 | 69 | from torchrec.distributed.utils import none_throws
|
68 | 70 |
|
| 71 | +logger: logging.Logger = logging.getLogger(__name__) |
| 72 | + |
69 | 73 |
|
70 | 74 | def to_sharding_plan(
|
71 | 75 | sharding_options: List[ShardingOption],
|
@@ -181,6 +185,7 @@ def __init__(
|
181 | 185 | callbacks: Optional[
|
182 | 186 | List[Callable[[List[ShardingOption]], List[ShardingOption]]]
|
183 | 187 | ] = None,
|
| 188 | + timeout_seconds: Optional[int] = None, |
184 | 189 | ) -> None:
|
185 | 190 | if topology is None:
|
186 | 191 | topology = Topology(
|
@@ -235,6 +240,9 @@ def __init__(
|
235 | 240 | self._callbacks: List[
|
236 | 241 | Callable[[List[ShardingOption]], List[ShardingOption]]
|
237 | 242 | ] = ([] if callbacks is None else callbacks)
|
| 243 | + if timeout_seconds is not None: |
| 244 | + assert timeout_seconds > 0, "Timeout must be positive" |
| 245 | + self._timeout_seconds = timeout_seconds |
238 | 246 |
|
239 | 247 | def collective_plan(
|
240 | 248 | self,
|
@@ -320,10 +328,19 @@ def plan(
|
320 | 328 | for proposer in self._proposers:
|
321 | 329 | proposer.load(search_space=search_space, enumerator=self._enumerator)
|
322 | 330 |
|
| 331 | + start = time.time() |
323 | 332 | for proposer in self._proposers:
|
324 | 333 | proposal = proposer.propose()
|
325 | 334 |
|
326 | 335 | while proposal:
|
| 336 | + end = time.time() |
| 337 | + elapsed = end - start |
| 338 | + if self._timeout_seconds: |
| 339 | + if elapsed > self._timeout_seconds: |
| 340 | + logger.info( |
| 341 | + f"Exceeded time limit of {self._timeout_seconds}s. Took {elapsed}s" |
| 342 | + ) |
| 343 | + break |
327 | 344 | proposal_key = tuple(sorted(map(hash, proposal)))
|
328 | 345 | if proposal_key in proposal_cache:
|
329 | 346 | partitionable, plan, perf_rating = proposal_cache[proposal_key]
|
|
0 commit comments