Skip to content

Commit 0ba17a6

Browse files
committed
Merge branch 'main' of github.com:tensorlayer/TensorLayerX into main
2 parents c9c0c40 + 2a37ec2 commit 0ba17a6

File tree

11 files changed

+1468
-729
lines changed

11 files changed

+1468
-729
lines changed

docs/modules/dataflow.rst

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,32 @@ Dataflow list
1212

1313
.. autosummary::
1414

15+
DataLoader
1516
Dataset
1617
IterableDataset
17-
FromGenerator
18-
FromSlices
19-
Dataloader
18+
TensorDataset
19+
ChainDataset
20+
ConcatDataset
21+
Subset
22+
random_split
23+
Sampler
24+
BatchSampler
25+
RandomSampler
26+
SequentialSampler
27+
WeightedRandomSampler
28+
SubsetRandomSampler
2029

21-
Concat
22-
Zip
23-
Batch
24-
Map
25-
Repeat
26-
Shuffle
2730

2831
.. -----------------------------------------------------------
2932
.. Dataflow
3033
.. -----------------------------------------------------------
3134
3235
Dataflow
3336
-----------------
37+
DataLoader
38+
^^^^^^^^^^^^^^^^
39+
.. autoclass:: DataLoader
40+
3441

3542
Dataset
3643
^^^^^^^^^^^^^^^^
@@ -41,39 +48,46 @@ IterableDataset
4148
^^^^^^^^^^^^^^^^
4249
.. autoclass:: IterableDataset
4350

44-
FromGenerator
51+
TensorDataset
4552
^^^^^^^^^^^^^^^^
46-
.. autoclass:: FromGenerator
53+
.. autoclass:: TensorDataset
4754

48-
FromSlices
55+
ChainDataset
4956
^^^^^^^^^^^^^^^^
50-
.. autoclass:: FromSlices
57+
.. autoclass:: ChainDataset
5158

52-
Dataloader
59+
ConcatDataset
5360
^^^^^^^^^^^^^^^^
54-
.. autoclass:: Dataloader
61+
.. autoclass:: ConcatDataset
5562

56-
Concat
63+
Subset
5764
^^^^^^^^^^^^^^^^
58-
.. autoclass:: Concat
65+
.. autoclass:: Subset
5966

60-
Zip
67+
random_split
6168
^^^^^^^^^^^^^^^^
62-
.. autoclass:: Zip
69+
.. autoclass:: random_split
6370

64-
Batch
71+
Sampler
6572
^^^^^^^^^^^^^^^^
66-
.. autoclass:: Batch
73+
.. autoclass:: Sampler
6774

68-
Map
69-
^^^^^^^^^^^^^^^^^^^^^
70-
.. autoclass:: Map
75+
BatchSampler
76+
^^^^^^^^^^^^^^^^
77+
.. autoclass:: BatchSampler
7178

72-
Repeat
73-
^^^^^^^^^^^^^^^^^^^^^
74-
.. autoclass:: Repeat
79+
RandomSampler
80+
^^^^^^^^^^^^^^^^
81+
.. autoclass:: RandomSampler
7582

76-
Shuffle
83+
SequentialSampler
7784
^^^^^^^^^^^^^^^^^^^^^
78-
.. autoclass:: Shuffle
85+
.. autoclass:: SequentialSampler
86+
87+
WeightedRandomSampler
88+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
89+
.. autoclass:: WeightedRandomSampler
7990

91+
SubsetRandomSampler
92+
^^^^^^^^^^^^^^^^^^^^^^^^^^^
93+
.. autoclass:: SubsetRandomSampler

tensorlayerx/dataflow/__init__.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,6 @@
11
#! /usr/bin/python
22
# -*- coding: utf-8 -*-
33
from __future__ import absolute_import, division, print_function
4-
5-
from tensorlayerx.backend.ops.load_backend import BACKEND
6-
7-
if BACKEND == 'tensorflow':
8-
from .tensorflow_data import *
9-
10-
elif BACKEND == 'mindspore':
11-
from .mindspore_data import *
12-
13-
elif BACKEND == 'paddle':
14-
from .paddle_data import *
15-
16-
elif BACKEND == 'torch':
17-
from .torch_data import *
18-
19-
else:
20-
raise NotImplementedError("This backend is not supported")
4+
from .dataloader import *
5+
from .sampler import *
6+
from .dataset import *
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
#! /usr/bin/python
2+
# -*- coding: utf-8 -*-
3+
from .dataset import Dataset, IterableDataset
4+
from .sampler import Sampler, SequentialSampler, RandomSampler, BatchSampler, SubsetRandomSampler, WeightedRandomSampler
5+
from .utils import _DatasetKind, _InfiniteIterableSampler
6+
from . import utils
7+
import math
8+
__all__ = [
9+
'DataLoader',
10+
]
11+
12+
13+
class DataLoader(object):
    """Iterable over a dataset, combining a dataset with a sampling strategy.

    Supports both map-style (``Dataset``) and iterable-style
    (``IterableDataset``) datasets, single- or multi-process loading,
    custom sampling order, and optional automatic batching.

    Parameters
    -----------
    dataset : Dataset
        dataset from which to load the data.
    batch_size : int
        how many samples per batch to load, default is 1.
    shuffle : bool
        set to ``True`` to have the data reshuffled at every epoch, default is ``False``.
    drop_last : bool
        set to ``True`` to drop the last incomplete batch,
        if the dataset size is not divisible by the batch size. If ``False`` and
        the size of dataset is not divisible by the batch size, then the last batch
        will be smaller. default is ``False``.
    sampler : Sampler
        defines the strategy to draw samples from the dataset. If specified, `shuffle` must not be specified.
    batch_sampler : Sampler
        returns a batch of indices at a time. If specified, `shuffle`, `batch_size`, `drop_last`, `sampler` must not be specified.
    num_workers : int
        how many subprocesses to use for data loading. ``0`` means that the data will be loaded in single process. default is ``0``.
    collate_fn : callable
        merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
    time_out : numeric
        if positive, the timeout value for collecting a batch from workers. Should always be non-negative. default is ``0``.
    worker_init_fn : callable
        If not ``None``, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as
        input, after seeding and before data loading. default is ``None``.
    prefetch_factor : int
        Number of samples loaded in advance by each worker.
        ``2`` means there will be a total of 2 * num_workers samples prefetched across all workers. default is ``2``
    persistent_workers : bool
        If ``True``, the data loader will not shutdown the worker processes after a dataset has been consumed once.
        This allows to maintain the workers `Dataset` instances alive. default is ``False``.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        shuffle=False,
        drop_last=False,
        sampler=None,
        batch_sampler=None,
        num_workers=0,
        collate_fn=None,
        # pin_memory = False,
        time_out=0,
        worker_init_fn=None,
        #multiprocessing_context=None,
        prefetch_factor=2,
        persistent_workers=False,
    ):
        # assert isinstance(dataset, Dataset), "dataset should be subclass of tensorlayerx.dataflow.Dataset"
        self.dataset = dataset

        # Worker-related options are only meaningful with num_workers > 0.
        assert num_workers >= 0, "num_workers should be a non_negative integer"
        if num_workers == 0 and prefetch_factor != 2:
            raise ValueError("prefetch_factor option should not be specified, when num_workers is 0.")
        if persistent_workers and num_workers == 0:
            raise ValueError('persistent_workers option needs num_workers > 0')
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        # self.pin_memory = pin_memory
        self.time_out = time_out
        self.worker_init_fn = worker_init_fn
        #self.multiprocessing_context = multiprocessing_context

        # Iterable-style datasets define their own iteration order, so any
        # explicit ordering option is rejected.
        if isinstance(dataset, IterableDataset):
            self._dataset_kind = _DatasetKind.Iter
            if shuffle is not False:
                raise ValueError("IterableDataset only support 'shuffle=False', but got shuffle={}.".format(shuffle))
            if sampler is not None:
                raise ValueError("IterableDataset only support 'sampler=None', but got sampler={}.".format(sampler))
            if batch_sampler is not None:
                raise ValueError(
                    "IterableDataset only support 'batch_sampler=None', "
                    "but got batch_sampler={}.".format(batch_sampler)
                )
        else:
            self._dataset_kind = _DatasetKind.Map

        if sampler is not None and shuffle:
            raise ValueError("sampler option is mutually exclusive with shuffle option.")

        if batch_sampler is not None:
            # batch_sampler fully determines batching; the other knobs must
            # stay at their defaults.
            if batch_size != 1 or shuffle or sampler is not None or drop_last:
                raise ValueError(
                    "batch_size, shuffle, sampler, drop_last should not be set, when batch_sampler is specified."
                )
            batch_size = None
            drop_last = False
        elif batch_size is None and drop_last:
            raise ValueError("drop_last should be False, when batch_size is None.")

        # Choose a default sampler when the caller did not provide one.
        if sampler is None:
            if self._dataset_kind == _DatasetKind.Iter:
                sampler = _InfiniteIterableSampler()
            else:
                sampler = RandomSampler(dataset) if shuffle else SequentialSampler(dataset)

        # Automatic batching: wrap the index sampler in a BatchSampler.
        if batch_size is not None and batch_sampler is None:
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)

        self.batch_size = batch_size
        self.drop_last = drop_last
        self.sampler = sampler
        self.batch_sampler = batch_sampler
        self._iterator = None

        # Batched loading collates samples into a mini-batch; unbatched
        # loading only converts a single sample.
        if collate_fn is None:
            collate_fn = utils.default_collate if self._is_batch else utils.default_convert
        self.collate_fn = collate_fn
        self.persistent_workers = persistent_workers

    @property
    def _is_batch(self):
        # True when automatic (or caller-supplied) batching is active.
        return self.batch_sampler is not None

    @property
    def _index_sampler(self):
        # The sampler that actually drives iteration: batched or plain.
        return self.batch_sampler if self._is_batch else self.sampler

    def _get_iterator(self):
        # Single-process iteration when no workers are requested.
        if self.num_workers == 0:
            return utils._SingleProcessDataLoaderIter(self)
        return utils._MultiProcessingDataLoaderIter(self)

    def __iter__(self):
        # Without persistent workers, every epoch gets a fresh iterator.
        if not (self.persistent_workers and self.num_workers > 0):
            return self._get_iterator()
        # Persistent workers: create the iterator once, then reset it
        # between epochs so the worker processes stay alive.
        if self._iterator is None:
            self._iterator = self._get_iterator()
        else:
            self._iterator._reset(self)
        return self._iterator

    def __len__(self):
        # Map-style: length is the number of (batched) indices.
        if self._dataset_kind != _DatasetKind.Iter:
            return len(self._index_sampler)
        # Iterable-style: derive batch count from the dataset's own length
        # (raises TypeError if the dataset does not define one).
        length = len(self.dataset)
        if self.batch_size is None:
            return length
        if self.drop_last:
            return length // self.batch_size
        return math.ceil(length / self.batch_size)

0 commit comments

Comments
 (0)