Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 9 additions & 21 deletions src/tiledbsoma_ml/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch
from attr import evolve
from attrs import define, field
from attrs.validators import gt
from attrs.validators import and_, instance_of, gt
from tiledbsoma import ExperimentAxisQuery
from torch.utils.data import IterableDataset

Expand Down Expand Up @@ -117,9 +117,9 @@ class ExperimentDataset(IterableDataset[MiniBatch]): # type: ignore[misc]
"""Names of ``obs`` columns to return."""

# Configuration fields with defaults
batch_size: int = field(default=1, validator=gt(0))
batch_size: int = field(default=1, validator=and_(instance_of(int), gt(0)))
"""Number of rows of ``X`` and ``obs`` data to yield in each |MiniBatch|."""
io_batch_size: int = field(default=DEFAULT_IO_BATCH_SIZE, validator=gt(0))
io_batch_size: int = field(default=DEFAULT_IO_BATCH_SIZE, validator=and_(instance_of(int), gt(0)))
"""Number of ``obs``/``X`` rows to fetch together, when reading from the provided |ExperimentAxisQuery|."""
shuffle: bool = field(default=True)
"""Whether to shuffle the ``obs`` and ``X`` data being returned."""
Expand Down Expand Up @@ -223,9 +223,7 @@ def __init__(
"""
if query and layer_name:
if x_locator or query_ids:
raise ValueError(
"Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
)
raise ValueError("Expected `{query,layer_name}` xor `{x_locator,query_ids}`")
query_ids = QueryIDs.create(query=query)
x_locator = XLocator.create(
query.experiment,
Expand All @@ -234,13 +232,9 @@ def __init__(
)
elif x_locator and query_ids:
if query or layer_name:
raise ValueError(
"Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
)
raise ValueError("Expected `{query,layer_name}` xor `{x_locator,query_ids}`")
else:
raise ValueError(
"Expected `{query,layer_name}` xor `{x_locator,query_ids}`"
)
raise ValueError("Expected `{query,layer_name}` xor `{x_locator,query_ids}`")

self.__attrs_init__(
x_locator=x_locator,
Expand All @@ -264,14 +258,10 @@ def __attrs_post_init__(self) -> None:
if self.shuffle:
# Verify `io_batch_size` is a multiple of `shuffle_chunk_size`
if self.io_batch_size % self.shuffle_chunk_size:
raise ValueError(
f"{self.io_batch_size=} is not a multiple of {self.shuffle_chunk_size=}"
)
raise ValueError(f"{self.io_batch_size=} is not a multiple of {self.shuffle_chunk_size=}")

if self.seed is None:
object.__setattr__(
self, "seed", np.random.default_rng().integers(0, 2**32 - 1)
)
object.__setattr__(self, "seed", np.random.default_rng().integers(0, 2**32 - 1))

# Set distributed state
rank, world_size = get_distributed_rank_and_world_size()
Expand Down Expand Up @@ -425,6 +415,4 @@ def set_epoch(self, epoch: int) -> None:
self.epoch = epoch

def __getitem__(self, index: int) -> MiniBatch:
raise NotImplementedError(
"`Experiment` can only be iterated - does not support mapping"
)
raise NotImplementedError("`Experiment` can only be iterated - does not support mapping")
Loading