@@ -48,6 +48,7 @@ def __init__(
4848 batch_size : int = 32 ,
4949 num_workers : int = 0 ,
5050 shuffle : bool = False ,
51+ on_after_load_sample : Optional [callable ] = None ,
5152 ):
5253 super ().__init__ ()
5354 self .data = data
@@ -60,6 +61,7 @@ def __init__(
6061 self .batch_size = batch_size
6162 self .num_workers = num_workers
6263 self .shuffle = shuffle
64+ self .on_after_load_sample = on_after_load_sample
6365
6466 @staticmethod
6567 def _get_lengths (fractions , N ):
@@ -72,15 +74,15 @@ def _get_lengths(fractions, N):
7274
7375 def setup (self , stage : str ):
7476 if isinstance (self .data , Samples ):
75- dataset = self .data .get_dataset ()
77+ dataset = self .data .get_dataset (on_after_load_sample = self . on_after_load_sample )
7678 splits = torch .utils .data .random_split (dataset , self .lengths )
7779 self .dataset_train , self .dataset_val , self .dataset_test = splits
7880 elif isinstance (self .data , swyft .ZarrStore ):
7981 idxr1 = (0 , self .lengths [1 ])
8082 idxr2 = (self .lengths [1 ], self .lengths [1 ] + self .lengths [2 ])
8183 idxr3 = (self .lengths [1 ] + self .lengths [2 ], len (self .data ))
8284 self .dataset_train = self .data .get_dataset (
83- idx_range = idxr1 , on_after_load_sample = None
85+ idx_range = idxr1 , on_after_load_sample = self . on_after_load_sample
8486 )
8587 self .dataset_val = self .data .get_dataset (
8688 idx_range = idxr2 , on_after_load_sample = None
0 commit comments