1111import torch
1212from attr import evolve
1313from attrs import define , field
14- from attrs .validators import gt
14+ from attrs .validators import and_ , instance_of , gt
1515from tiledbsoma import ExperimentAxisQuery
1616from torch .utils .data import IterableDataset
1717
@@ -117,9 +117,9 @@ class ExperimentDataset(IterableDataset[MiniBatch]): # type: ignore[misc]
117117 """Names of ``obs`` columns to return."""
118118
119119 # Configuration fields with defaults
120- batch_size : int = field (default = 1 , validator = gt (0 ))
120+ batch_size : int = field (default = 1 , validator = and_ ( instance_of ( int ), gt (0 ) ))
121121 """Number of rows of ``X`` and ``obs`` data to yield in each |MiniBatch|."""
122- io_batch_size : int = field (default = DEFAULT_IO_BATCH_SIZE , validator = gt (0 ))
122+ io_batch_size : int = field (default = DEFAULT_IO_BATCH_SIZE , validator = and_ ( instance_of ( int ), gt (0 ) ))
123123 """Number of ``obs``/``X`` rows to fetch together, when reading from the provided |ExperimentAxisQuery|."""
124124 shuffle : bool = field (default = True )
125125 """Whether to shuffle the ``obs`` and ``X`` data being returned."""
@@ -223,9 +223,7 @@ def __init__(
223223 """
224224 if query and layer_name :
225225 if x_locator or query_ids :
226- raise ValueError (
227- "Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
228- )
226+ raise ValueError ("Expected `{query,layer_name}` xor `{x_locator,query_ids}`" )
229227 query_ids = QueryIDs .create (query = query )
230228 x_locator = XLocator .create (
231229 query .experiment ,
@@ -234,13 +232,9 @@ def __init__(
234232 )
235233 elif x_locator and query_ids :
236234 if query or layer_name :
237- raise ValueError (
238- "Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
239- )
235+ raise ValueError ("Expected `{query,layer_name}` xor `{x_locator,query_ids}`" )
240236 else :
241- raise ValueError (
242- "Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
243- )
237+ raise ValueError ("Expected `{query,layer_name}` xor `{x_locator,query_ids}`" )
244238
245239 self .__attrs_init__ (
246240 x_locator = x_locator ,
@@ -264,14 +258,10 @@ def __attrs_post_init__(self) -> None:
264258 if self .shuffle :
265259 # Verify `io_batch_size` is a multiple of `shuffle_chunk_size`
266260 if self .io_batch_size % self .shuffle_chunk_size :
267- raise ValueError (
268- f"{ self .io_batch_size = } is not a multiple of { self .shuffle_chunk_size = } "
269- )
261+ raise ValueError (f"{ self .io_batch_size = } is not a multiple of { self .shuffle_chunk_size = } " )
270262
271263 if self .seed is None :
272- object .__setattr__ (
273- self , "seed" , np .random .default_rng ().integers (0 , 2 ** 32 - 1 )
274- )
264+ object .__setattr__ (self , "seed" , np .random .default_rng ().integers (0 , 2 ** 32 - 1 ))
275265
276266 # Set distributed state
277267 rank , world_size = get_distributed_rank_and_world_size ()
@@ -425,6 +415,4 @@ def set_epoch(self, epoch: int) -> None:
425415 self .epoch = epoch
426416
427417 def __getitem__ (self , index : int ) -> MiniBatch :
428- raise NotImplementedError (
429- "`Experiment` can only be iterated - does not support mapping"
430- )
418+ raise NotImplementedError ("`Experiment` can only be iterated - does not support mapping" )
0 commit comments