Skip to content

Commit 8b60aae

Browse files
authored
Add type checks to batch size validators (#36)
* Add type checks to batch size validators
1 parent 222152a commit 8b60aae

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

src/tiledbsoma_ml/dataset.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import torch
1212
from attr import evolve
1313
from attrs import define, field
14-
from attrs.validators import gt
14+
from attrs.validators import and_, gt, instance_of
1515
from tiledbsoma import ExperimentAxisQuery
1616
from torch.utils.data import IterableDataset
1717

@@ -117,9 +117,11 @@ class ExperimentDataset(IterableDataset[MiniBatch]): # type: ignore[misc]
117117
"""Names of ``obs`` columns to return."""
118118

119119
# Configuration fields with defaults
120-
batch_size: int = field(default=1, validator=gt(0))
120+
batch_size: int = field(default=1, validator=and_(instance_of(int), gt(0)))
121121
"""Number of rows of ``X`` and ``obs`` data to yield in each |MiniBatch|."""
122-
io_batch_size: int = field(default=DEFAULT_IO_BATCH_SIZE, validator=gt(0))
122+
io_batch_size: int = field(
123+
default=DEFAULT_IO_BATCH_SIZE, validator=and_(instance_of(int), gt(0))
124+
)
123125
"""Number of ``obs``/``X`` rows to fetch together, when reading from the provided |ExperimentAxisQuery|."""
124126
shuffle: bool = field(default=True)
125127
"""Whether to shuffle the ``obs`` and ``X`` data being returned."""

0 commit comments

Comments
 (0)