|
10 | 10 | import dataclasses
|
11 | 11 | import os
|
12 | 12 | import unittest
|
13 |
| -from collections import defaultdict, namedtuple |
| 13 | +from collections import defaultdict, namedtuple, UserDict |
14 | 14 | from dataclasses import dataclass
|
15 | 15 | from typing import Any, Dict
|
16 | 16 | from unittest import mock
|
@@ -104,6 +104,21 @@ def test_copy_data_to_device_dict(self) -> None:
|
104 | 104 | for key in new_dict.keys():
|
105 | 105 | self.assertEqual(new_dict[key].device.type, "cuda")
|
106 | 106 |
|
| 107 | + @skip_if_not_gpu |
| 108 | + def test_copy_data_to_device_mapping(self) -> None: |
| 109 | + cuda_0 = torch.device("cuda:0") |
| 110 | + f = torch.tensor([1, 2, 3]) |
| 111 | + g = torch.tensor([4, 5, 6]) |
| 112 | + |
| 113 | + # Use UserDict instead of a regular dictionary |
| 114 | + original_dict = UserDict({"f": f, "g": g}) |
| 115 | + |
| 116 | + self.assertEqual(f.device.type, "cpu") |
| 117 | + self.assertEqual(g.device.type, "cpu") |
| 118 | + new_dict = copy_data_to_device(original_dict, cuda_0) |
| 119 | + for key in new_dict.keys(): |
| 120 | + self.assertEqual(new_dict[key].device.type, "cuda") |
| 121 | + |
107 | 122 | @skip_if_not_gpu
|
108 | 123 | def test_copy_data_to_device_named_tuple(self) -> None:
|
109 | 124 | cuda_0 = torch.device("cuda:0")
|
|
0 commit comments