
Commit 5534617

JKSenthil authored and facebook-github-bot committed
check set_epoch exists on sampler (#789)
Summary: Pull Request resolved: #789

Reviewed By: diego-urgell

Differential Revision: D56262859

fbshipit-source-id: c467e555e4529e25c31ba0ba39a95b505337e1c6
1 parent c444003 commit 5534617


torchtnt/framework/_loop_utils.py

Lines changed: 7 additions & 2 deletions
@@ -7,7 +7,7 @@
 # pyre-strict
 
 import logging
-from typing import Dict, Iterable, Optional, TypeVar
+from typing import Dict, Iterable, Optional, Protocol, runtime_checkable, TypeVar
 
 import torch
 import torch.nn as nn
@@ -37,6 +37,11 @@ def _is_epoch_done(
     )
 
 
+@runtime_checkable
+class _DistributedSampler(Protocol):
+    def set_epoch(self, epoch: int) -> None: ...
+
+
 def _maybe_set_distributed_sampler_epoch(
     dataloader: Iterable[object],
     current_epoch: int,
@@ -47,7 +52,7 @@ def _maybe_set_distributed_sampler_epoch(
     # Set current training epoch for any DistributedSampler in dataloader
     if isinstance(dataloader, torch.utils.data.DataLoader) and isinstance(
         dataloader.sampler,
-        torch.utils.data.distributed.DistributedSampler,
+        _DistributedSampler,
     ):
         dataloader.sampler.set_epoch(current_epoch)

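The diff replaces a nominal isinstance check against torch.utils.data.distributed.DistributedSampler with a structural one: any sampler exposing a set_epoch(epoch: int) method is now treated as epoch-aware, because isinstance against a @runtime_checkable Protocol only tests for the presence of the named method. Below is a minimal, self-contained sketch of that behavior; CustomSampler is a hypothetical example class, not part of torchtnt or PyTorch.

    from typing import Iterator, Protocol, runtime_checkable

    from torch.utils.data import Sampler


    @runtime_checkable
    class _DistributedSampler(Protocol):
        def set_epoch(self, epoch: int) -> None: ...


    class CustomSampler(Sampler[int]):
        """Hypothetical sampler: not a DistributedSampler subclass,
        but structurally compatible because it defines set_epoch."""

        def __init__(self, length: int) -> None:
            self.length = length
            self.epoch = 0

        def __iter__(self) -> Iterator[int]:
            # Ordering could depend on self.epoch; kept deterministic for brevity.
            return iter(range(self.length))

        def __len__(self) -> int:
            return self.length

        def set_epoch(self, epoch: int) -> None:
            self.epoch = epoch


    sampler = CustomSampler(8)
    # Passes: the Protocol isinstance check only verifies that a
    # set_epoch attribute exists, not that the class subclasses
    # torch.utils.data.distributed.DistributedSampler.
    assert isinstance(sampler, _DistributedSampler)
    sampler.set_epoch(3)

One caveat of this approach: runtime_checkable Protocols verify only that the method exists, not its signature, so any sampler with a set_epoch attribute will match the check.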