
Commit 4d64cb8

diego-urgell authored and facebook-github-bot committed
Split data_prefetcher GPU tests into a different file (#772)
Summary: Pull Request resolved: #772

Reviewed By: JKSenthil

Differential Revision: D55495109

fbshipit-source-id: 0e2048d8b2744ccd306817d69c1bd189abaa24a9
1 parent 94344e2 commit 4d64cb8

File tree: 2 files changed, +79 -54 lines

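Both files test torchtnt's CudaDataPrefetcher. As a minimal usage sketch, assuming only the behavior the tests below exercise (the constructor arguments, the num_prefetch_batches attribute, the validation errors, and iteration yielding device-resident batches):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torchtnt.utils.data.data_prefetcher import CudaDataPrefetcher

# Build the same kind of toy dataset the tests use: random inputs, binary labels.
data = torch.randn(12, 2)
labels = torch.randint(low=0, high=2, size=(12,))
dataloader = DataLoader(TensorDataset(data, labels), batch_size=4)

# CudaDataPrefetcher requires a CUDA device (a CPU device raises
# ValueError "expects a CUDA device") and num_prefetch_batches > 0.
device = torch.device("cuda:0")
prefetcher = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=2)

for inputs, targets in prefetcher:
    # Batches are yielded already moved onto the CUDA device.
    assert inputs.device == device and targets.device == device

With 12 samples and a batch size of 4, iterating the prefetcher yields 3 batches, which is exactly what test_cuda_data_prefetcher asserts via num_samples / batch_size.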

tests/utils/data/test_data_prefetcher.py

Lines changed: 2 additions & 54 deletions
@@ -13,19 +13,18 @@
 import torch
 from torch.utils.data.dataset import Dataset, TensorDataset
 from torchtnt.utils.data.data_prefetcher import CudaDataPrefetcher
-from torchtnt.utils.test_utils import skip_if_not_gpu

 Batch = Tuple[torch.Tensor, torch.Tensor]


-class DataTest(unittest.TestCase):
+class DataPrefetcherTest(unittest.TestCase):
     def _generate_dataset(self, num_samples: int, input_dim: int) -> Dataset[Batch]:
         """Returns a dataset of random inputs and labels for binary classification."""
         data = torch.randn(num_samples, input_dim)
         labels = torch.randint(low=0, high=2, size=(num_samples,))
         return TensorDataset(data, labels)

-    def test_cpu_device_data_prefetcher(self) -> None:
+    def test_device_data_prefetcher(self) -> None:
         device = torch.device("cpu")

         num_samples = 12
@@ -37,54 +36,3 @@ def test_cpu_device_data_prefetcher(self) -> None:
         num_prefetch_batches = 2
         with self.assertRaisesRegex(ValueError, "expects a CUDA device"):
             _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)
-
-    @skip_if_not_gpu
-    def test_num_prefetch_batches_data_prefetcher(self) -> None:
-        device = torch.device("cuda:0")
-
-        num_samples = 12
-        batch_size = 4
-        dataloader = torch.utils.data.DataLoader(
-            self._generate_dataset(num_samples, 2), batch_size=batch_size
-        )
-
-        with self.assertRaisesRegex(
-            ValueError, "`num_prefetch_batches` must be greater than 0"
-        ):
-            _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=-1)
-
-        with self.assertRaisesRegex(
-            ValueError, "`num_prefetch_batches` must be greater than 0"
-        ):
-            _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=0)
-
-        # no exceptions raised
-        _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=1)
-        _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=2)
-
-    @skip_if_not_gpu
-    def test_cuda_data_prefetcher(self) -> None:
-        device = torch.device("cuda:0")
-
-        num_samples = 12
-        batch_size = 4
-        dataloader = torch.utils.data.DataLoader(
-            self._generate_dataset(num_samples, 2), batch_size=batch_size
-        )
-
-        num_prefetch_batches = 2
-        data_prefetcher = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)
-        self.assertEqual(num_prefetch_batches, data_prefetcher.num_prefetch_batches)
-
-        # make sure data_prefetcher has same number of samples as original dataloader
-        num_batches_in_data_prefetcher = 0
-        for inputs, targets in data_prefetcher:
-            num_batches_in_data_prefetcher += 1
-            # len(inputs) should equal the batch size
-            self.assertEqual(len(inputs), batch_size)
-            self.assertEqual(len(targets), batch_size)
-            # make sure batch is on correct device
-            self.assertEqual(inputs.device, device)
-            self.assertEqual(targets.device, device)
-
-        self.assertEqual(num_batches_in_data_prefetcher, num_samples / batch_size)
New file

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+import unittest
+from typing import Tuple
+
+import torch
+
+from torch.utils.data import Dataset, TensorDataset
+from torchtnt.utils.data.data_prefetcher import CudaDataPrefetcher
+from torchtnt.utils.test_utils import skip_if_not_gpu
+
+Batch = Tuple[torch.Tensor, torch.Tensor]
+
+
+class DataPrefetcherGPUTest(unittest.TestCase):
+    def _generate_dataset(self, num_samples: int, input_dim: int) -> Dataset[Batch]:
+        """Returns a dataset of random inputs and labels for binary classification."""
+        data = torch.randn(num_samples, input_dim)
+        labels = torch.randint(low=0, high=2, size=(num_samples,))
+        return TensorDataset(data, labels)
+
+    @skip_if_not_gpu
+    def test_num_prefetch_batches_data_prefetcher(self) -> None:
+        device = torch.device("cuda:0")
+
+        num_samples = 12
+        batch_size = 4
+        dataloader = torch.utils.data.DataLoader(
+            self._generate_dataset(num_samples, 2), batch_size=batch_size
+        )
+
+        with self.assertRaisesRegex(
+            ValueError, "`num_prefetch_batches` must be greater than 0"
+        ):
+            _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=-1)
+
+        with self.assertRaisesRegex(
+            ValueError, "`num_prefetch_batches` must be greater than 0"
+        ):
+            _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=0)
+
+        # no exceptions raised
+        _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=1)
+        _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=2)
+
+    @skip_if_not_gpu
+    def test_cuda_data_prefetcher(self) -> None:
+        device = torch.device("cuda:0")
+
+        num_samples = 12
+        batch_size = 4
+        dataloader = torch.utils.data.DataLoader(
+            self._generate_dataset(num_samples, 2), batch_size=batch_size
+        )
+
+        num_prefetch_batches = 2
+        data_prefetcher = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)
+        self.assertEqual(num_prefetch_batches, data_prefetcher.num_prefetch_batches)
+
+        # make sure data_prefetcher has same number of samples as original dataloader
+        num_batches_in_data_prefetcher = 0
+        for inputs, targets in data_prefetcher:
+            num_batches_in_data_prefetcher += 1
+            # len(inputs) should equal the batch size
+            self.assertEqual(len(inputs), batch_size)
+            self.assertEqual(len(targets), batch_size)
+            # make sure batch is on correct device
+            self.assertEqual(inputs.device, device)
+            self.assertEqual(targets.device, device)
+
+        self.assertEqual(num_batches_in_data_prefetcher, num_samples / batch_size)
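The GPU-only tests above are gated with torchtnt's skip_if_not_gpu decorator from torchtnt.utils.test_utils, whose implementation is not part of this diff. A rough sketch of how such a guard can be written with standard unittest machinery, assuming it simply skips when no CUDA device is available (skip_if_no_cuda below is a hypothetical stand-in, not torchtnt's code):

import unittest

import torch

# Hypothetical stand-in for torchtnt.utils.test_utils.skip_if_not_gpu:
# skip the decorated test unless a CUDA device is available.
skip_if_no_cuda = unittest.skipUnless(
    torch.cuda.is_available(), "this test requires a CUDA-capable GPU"
)


class ExampleGPUTest(unittest.TestCase):
    @skip_if_no_cuda
    def test_needs_gpu(self) -> None:
        self.assertTrue(torch.cuda.is_available())

Splitting the decorated tests into their own file keeps the remaining CPU test module free of GPU-only imports and lets GPU-capable jobs target just the new file.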

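On a CUDA machine, the new GPU test module can be run on its own with the standard unittest loader. A sketch for illustration only; the module path below is an assumption, since the commit view above does not show the new file's name:

import unittest

if __name__ == "__main__":
    # Assumed module name for the new GPU test file (not confirmed by the diff above).
    suite = unittest.defaultTestLoader.loadTestsFromName(
        "tests.utils.data.test_data_prefetcher_gpu"
    )
    unittest.TextTestRunner(verbosity=2).run(suite)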