Skip to content

Commit 244a3e9

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Fix data_prefetcher test error (#774)
Summary: Pull Request resolved: #774 Reviewed By: galrotem Differential Revision: D55514862 fbshipit-source-id: 43c0bb7d549dfeba4e57efb3c9b29dd025efe5cb
1 parent 4d64cb8 commit 244a3e9

File tree

2 files changed

+28
-24
lines changed

2 files changed

+28
-24
lines changed

tests/utils/data/test_data_prefetcher.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import unittest
1111
from typing import Tuple
12+
from unittest.mock import MagicMock, patch
1213

1314
import torch
1415
from torch.utils.data.dataset import Dataset, TensorDataset
@@ -36,3 +37,30 @@ def test_device_data_prefetcher(self) -> None:
3637
num_prefetch_batches = 2
3738
with self.assertRaisesRegex(ValueError, "expects a CUDA device"):
3839
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)
40+
41+
@patch("torch.cuda.Stream")
42+
def test_num_prefetch_batches_data_prefetcher(self, mock_stream: MagicMock) -> None:
43+
device = torch.device("cuda:0")
44+
45+
num_samples = 12
46+
batch_size = 4
47+
dataloader = torch.utils.data.DataLoader(
48+
self._generate_dataset(num_samples, 2), batch_size=batch_size
49+
)
50+
51+
with self.assertRaisesRegex(
52+
ValueError, "`num_prefetch_batches` must be greater than 0"
53+
):
54+
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=-1)
55+
56+
with self.assertRaisesRegex(
57+
ValueError, "`num_prefetch_batches` must be greater than 0"
58+
):
59+
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=0)
60+
61+
# no exceptions raised
62+
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=1)
63+
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=2)
64+
65+
# Check that CUDA streams were created
66+
self.assertEqual(mock_stream.call_count, 2)

tests/utils/data/test_data_prefetcher_gpu.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -25,30 +25,6 @@ def _generate_dataset(self, num_samples: int, input_dim: int) -> Dataset[Batch]:
2525
labels = torch.randint(low=0, high=2, size=(num_samples,))
2626
return TensorDataset(data, labels)
2727

28-
@skip_if_not_gpu
29-
def test_num_prefetch_batches_data_prefetcher(self) -> None:
30-
device = torch.device("cuda:0")
31-
32-
num_samples = 12
33-
batch_size = 4
34-
dataloader = torch.utils.data.DataLoader(
35-
self._generate_dataset(num_samples, 2), batch_size=batch_size
36-
)
37-
38-
with self.assertRaisesRegex(
39-
ValueError, "`num_prefetch_batches` must be greater than 0"
40-
):
41-
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=-1)
42-
43-
with self.assertRaisesRegex(
44-
ValueError, "`num_prefetch_batches` must be greater than 0"
45-
):
46-
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=0)
47-
48-
# no exceptions raised
49-
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=1)
50-
_ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=2)
51-
5228
@skip_if_not_gpu
5329
def test_cuda_data_prefetcher(self) -> None:
5430
device = torch.device("cuda:0")

0 commit comments

Comments
 (0)