13
13
import torch
14
14
from torch .utils .data .dataset import Dataset , TensorDataset
15
15
from torchtnt .utils .data .data_prefetcher import CudaDataPrefetcher
16
- from torchtnt .utils .test_utils import skip_if_not_gpu
17
16
18
17
Batch = Tuple [torch .Tensor , torch .Tensor ]
19
18
20
19
21
- class DataTest (unittest .TestCase ):
20
+ class DataPrefetcherTest (unittest .TestCase ):
22
21
def _generate_dataset (self , num_samples : int , input_dim : int ) -> Dataset [Batch ]:
23
22
"""Returns a dataset of random inputs and labels for binary classification."""
24
23
data = torch .randn (num_samples , input_dim )
25
24
labels = torch .randint (low = 0 , high = 2 , size = (num_samples ,))
26
25
return TensorDataset (data , labels )
27
26
28
- def test_cpu_device_data_prefetcher (self ) -> None :
27
+ def test_device_data_prefetcher (self ) -> None :
29
28
device = torch .device ("cpu" )
30
29
31
30
num_samples = 12
@@ -37,54 +36,3 @@ def test_cpu_device_data_prefetcher(self) -> None:
37
36
num_prefetch_batches = 2
38
37
with self .assertRaisesRegex (ValueError , "expects a CUDA device" ):
39
38
_ = CudaDataPrefetcher (dataloader , device , num_prefetch_batches )
40
-
41
- @skip_if_not_gpu
42
- def test_num_prefetch_batches_data_prefetcher (self ) -> None :
43
- device = torch .device ("cuda:0" )
44
-
45
- num_samples = 12
46
- batch_size = 4
47
- dataloader = torch .utils .data .DataLoader (
48
- self ._generate_dataset (num_samples , 2 ), batch_size = batch_size
49
- )
50
-
51
- with self .assertRaisesRegex (
52
- ValueError , "`num_prefetch_batches` must be greater than 0"
53
- ):
54
- _ = CudaDataPrefetcher (dataloader , device , num_prefetch_batches = - 1 )
55
-
56
- with self .assertRaisesRegex (
57
- ValueError , "`num_prefetch_batches` must be greater than 0"
58
- ):
59
- _ = CudaDataPrefetcher (dataloader , device , num_prefetch_batches = 0 )
60
-
61
- # no exceptions raised
62
- _ = CudaDataPrefetcher (dataloader , device , num_prefetch_batches = 1 )
63
- _ = CudaDataPrefetcher (dataloader , device , num_prefetch_batches = 2 )
64
-
65
- @skip_if_not_gpu
66
- def test_cuda_data_prefetcher (self ) -> None :
67
- device = torch .device ("cuda:0" )
68
-
69
- num_samples = 12
70
- batch_size = 4
71
- dataloader = torch .utils .data .DataLoader (
72
- self ._generate_dataset (num_samples , 2 ), batch_size = batch_size
73
- )
74
-
75
- num_prefetch_batches = 2
76
- data_prefetcher = CudaDataPrefetcher (dataloader , device , num_prefetch_batches )
77
- self .assertEqual (num_prefetch_batches , data_prefetcher .num_prefetch_batches )
78
-
79
- # make sure data_prefetcher has same number of samples as original dataloader
80
- num_batches_in_data_prefetcher = 0
81
- for inputs , targets in data_prefetcher :
82
- num_batches_in_data_prefetcher += 1
83
- # len(inputs) should equal the batch size
84
- self .assertEqual (len (inputs ), batch_size )
85
- self .assertEqual (len (targets ), batch_size )
86
- # make sure batch is on correct device
87
- self .assertEqual (inputs .device , device )
88
- self .assertEqual (targets .device , device )
89
-
90
- self .assertEqual (num_batches_in_data_prefetcher , num_samples / batch_size )
0 commit comments