Skip to content

Commit 31d1012

Browse files
committed
Add type checks to batch size validators
1 parent 222152a commit 31d1012

File tree

1 file changed

+9
-21
lines changed

1 file changed

+9
-21
lines changed

src/tiledbsoma_ml/dataset.py

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from attr import evolve
1313
from attrs import define, field
14-
from attrs.validators import gt
14+
from attrs.validators import and_, instance_of, gt
1515
from tiledbsoma import ExperimentAxisQuery
1616
from torch.utils.data import IterableDataset
1717

@@ -117,9 +117,9 @@ class ExperimentDataset(IterableDataset[MiniBatch]): # type: ignore[misc]
117117
"""Names of ``obs`` columns to return."""
118118

119119
# Configuration fields with defaults
120-
batch_size: int = field(default=1, validator=gt(0))
120+
batch_size: int = field(default=1, validator=and_(instance_of(int), gt(0)))
121121
"""Number of rows of ``X`` and ``obs`` data to yield in each |MiniBatch|."""
122-
io_batch_size: int = field(default=DEFAULT_IO_BATCH_SIZE, validator=gt(0))
122+
io_batch_size: int = field(default=DEFAULT_IO_BATCH_SIZE, validator=and_(instance_of(int), gt(0)))
123123
"""Number of ``obs``/``X`` rows to fetch together, when reading from the provided |ExperimentAxisQuery|."""
124124
shuffle: bool = field(default=True)
125125
"""Whether to shuffle the ``obs`` and ``X`` data being returned."""
@@ -223,9 +223,7 @@ def __init__(
223223
"""
224224
if query and layer_name:
225225
if x_locator or query_ids:
226-
raise ValueError(
227-
"Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
228-
)
226+
raise ValueError("Expected `{query,layer_name}` xor `{x_locator,query_ids}`")
229227
query_ids = QueryIDs.create(query=query)
230228
x_locator = XLocator.create(
231229
query.experiment,
@@ -234,13 +232,9 @@ def __init__(
234232
)
235233
elif x_locator and query_ids:
236234
if query or layer_name:
237-
raise ValueError(
238-
"Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
239-
)
235+
raise ValueError("Expected `{query,layer_name}` xor `{x_locator,query_ids}`")
240236
else:
241-
raise ValueError(
242-
"Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
243-
)
237+
raise ValueError("Expected `{query,layer_name}` xor `{x_locator,query_ids}`")
244238

245239
self.__attrs_init__(
246240
x_locator=x_locator,
@@ -264,14 +258,10 @@ def __attrs_post_init__(self) -> None:
264258
if self.shuffle:
265259
# Verify `io_batch_size` is a multiple of `shuffle_chunk_size`
266260
if self.io_batch_size % self.shuffle_chunk_size:
267-
raise ValueError(
268-
f"{self.io_batch_size=} is not a multiple of {self.shuffle_chunk_size=}"
269-
)
261+
raise ValueError(f"{self.io_batch_size=} is not a multiple of {self.shuffle_chunk_size=}")
270262

271263
if self.seed is None:
272-
object.__setattr__(
273-
self, "seed", np.random.default_rng().integers(0, 2**32 - 1)
274-
)
264+
object.__setattr__(self, "seed", np.random.default_rng().integers(0, 2**32 - 1))
275265

276266
# Set distributed state
277267
rank, world_size = get_distributed_rank_and_world_size()
@@ -425,6 +415,4 @@ def set_epoch(self, epoch: int) -> None:
425415
self.epoch = epoch
426416

427417
def __getitem__(self, index: int) -> MiniBatch:
428-
raise NotImplementedError(
429-
"`Experiment` can only be iterated - does not support mapping"
430-
)
418+
raise NotImplementedError("`Experiment` can only be iterated - does not support mapping")

0 commit comments

Comments
 (0)