diff --git a/src/tiledbsoma_ml/dataset.py b/src/tiledbsoma_ml/dataset.py index 80b74a7..42b7d7b 100644 --- a/src/tiledbsoma_ml/dataset.py +++ b/src/tiledbsoma_ml/dataset.py @@ -11,7 +11,7 @@ import torch from attr import evolve from attrs import define, field -from attrs.validators import gt +from attrs.validators import and_, gt, instance_of from tiledbsoma import ExperimentAxisQuery from torch.utils.data import IterableDataset @@ -117,9 +117,11 @@ class ExperimentDataset(IterableDataset[MiniBatch]): # type: ignore[misc] """Names of ``obs`` columns to return.""" # Configuration fields with defaults - batch_size: int = field(default=1, validator=gt(0)) + batch_size: int = field(default=1, validator=and_(instance_of(int), gt(0))) """Number of rows of ``X`` and ``obs`` data to yield in each |MiniBatch|.""" - io_batch_size: int = field(default=DEFAULT_IO_BATCH_SIZE, validator=gt(0)) + io_batch_size: int = field( + default=DEFAULT_IO_BATCH_SIZE, validator=and_(instance_of(int), gt(0)) + ) """Number of ``obs``/``X`` rows to fetch together, when reading from the provided |ExperimentAxisQuery|.""" shuffle: bool = field(default=True) """Whether to shuffle the ``obs`` and ``X`` data being returned."""