|
9 | 9 |
|
10 | 10 | import unittest
|
11 | 11 | from typing import Tuple
|
| 12 | +from unittest.mock import MagicMock, patch |
12 | 13 |
|
13 | 14 | import torch
|
14 | 15 | from torch.utils.data.dataset import Dataset, TensorDataset
|
@@ -36,3 +37,30 @@ def test_device_data_prefetcher(self) -> None:
|
36 | 37 | num_prefetch_batches = 2
|
37 | 38 | with self.assertRaisesRegex(ValueError, "expects a CUDA device"):
|
38 | 39 | _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches)
|
| 40 | + |
| 41 | + @patch("torch.cuda.Stream") |
| 42 | + def test_num_prefetch_batches_data_prefetcher(self, mock_stream: MagicMock) -> None: |
| 43 | + device = torch.device("cuda:0") |
| 44 | + |
| 45 | + num_samples = 12 |
| 46 | + batch_size = 4 |
| 47 | + dataloader = torch.utils.data.DataLoader( |
| 48 | + self._generate_dataset(num_samples, 2), batch_size=batch_size |
| 49 | + ) |
| 50 | + |
| 51 | + with self.assertRaisesRegex( |
| 52 | + ValueError, "`num_prefetch_batches` must be greater than 0" |
| 53 | + ): |
| 54 | + _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=-1) |
| 55 | + |
| 56 | + with self.assertRaisesRegex( |
| 57 | + ValueError, "`num_prefetch_batches` must be greater than 0" |
| 58 | + ): |
| 59 | + _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=0) |
| 60 | + |
| 61 | + # no exceptions raised |
| 62 | + _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=1) |
| 63 | + _ = CudaDataPrefetcher(dataloader, device, num_prefetch_batches=2) |
| 64 | + |
| 65 | + # Check that CUDA streams were created |
| 66 | + self.assertEqual(mock_stream.call_count, 2) |
0 commit comments