Skip to content

Commit af56155

Browse files
committed
Remove deprecation warning, add on_after_load_sample option
modified: data.py modified: simulator.py
1 parent 47093fa commit af56155

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

swyft/lightning/data.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def __init__(
4848
batch_size: int = 32,
4949
num_workers: int = 0,
5050
shuffle: bool = False,
51+
on_after_load_sample: Optional[callable] = None,
5152
):
5253
super().__init__()
5354
self.data = data
@@ -60,6 +61,7 @@ def __init__(
6061
self.batch_size = batch_size
6162
self.num_workers = num_workers
6263
self.shuffle = shuffle
64+
self.on_after_load_sample = on_after_load_sample
6365

6466
@staticmethod
6567
def _get_lengths(fractions, N):
@@ -72,15 +74,15 @@ def _get_lengths(fractions, N):
7274

7375
def setup(self, stage: str):
7476
if isinstance(self.data, Samples):
75-
dataset = self.data.get_dataset()
77+
dataset = self.data.get_dataset(on_after_load_sample = self.on_after_load_sample)
7678
splits = torch.utils.data.random_split(dataset, self.lengths)
7779
self.dataset_train, self.dataset_val, self.dataset_test = splits
7880
elif isinstance(self.data, swyft.ZarrStore):
7981
idxr1 = (0, self.lengths[1])
8082
idxr2 = (self.lengths[1], self.lengths[1] + self.lengths[2])
8183
idxr3 = (self.lengths[1] + self.lengths[2], len(self.data))
8284
self.dataset_train = self.data.get_dataset(
83-
idx_range=idxr1, on_after_load_sample=None
85+
idx_range=idxr1, on_after_load_sample=self.on_after_load_sample
8486
)
8587
self.dataset_val = self.data.get_dataset(
8688
idx_range=idxr2, on_after_load_sample=None

swyft/lightning/simulator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,14 @@ def get_dataloader(
7676
repeat=None,
7777
num_workers=0,
7878
):
79-
"""(Deprecated) Generator function to directly generate a dataloader object.
79+
"""Generator function to directly generate a dataloader object.
8080
8181
Args:
8282
batch_size: batch_size for dataloader
8383
shuffle: shuffle for dataloader
8484
on_after_load_sample: see `get_dataset`
8585
repeat: If not None, Wrap dataset in RepeatDatasetWrapper
8686
"""
87-
print("WARNING: Deprecated")
8887
dataset = self.get_dataset(on_after_load_sample=on_after_load_sample)
8988
if repeat is not None:
9089
dataset = swyft.lightning.data.RepeatDatasetWrapper(dataset, repeat=repeat)

0 commit comments

Comments
 (0)