Skip to content

Commit 52b5568

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
improve dict/Mapping check in copy_data_to_device (#958)
Summary: Pull Request resolved: #958 Reviewed By: diego-urgell Differential Revision: D67962833 fbshipit-source-id: 347f8fd222b96f582f8c7a4e780e29750057b885
1 parent 6e6824c commit 52b5568

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

tests/utils/test_device_gpu.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import dataclasses
1111
import os
1212
import unittest
13-
from collections import defaultdict, namedtuple
13+
from collections import defaultdict, namedtuple, UserDict
1414
from dataclasses import dataclass
1515
from typing import Any, Dict
1616
from unittest import mock
@@ -104,6 +104,21 @@ def test_copy_data_to_device_dict(self) -> None:
104104
for key in new_dict.keys():
105105
self.assertEqual(new_dict[key].device.type, "cuda")
106106

107+
@skip_if_not_gpu
108+
def test_copy_data_to_device_mapping(self) -> None:
109+
cuda_0 = torch.device("cuda:0")
110+
f = torch.tensor([1, 2, 3])
111+
g = torch.tensor([4, 5, 6])
112+
113+
# Use UserDict instead of a regular dictionary
114+
original_dict = UserDict({"f": f, "g": g})
115+
116+
self.assertEqual(f.device.type, "cpu")
117+
self.assertEqual(g.device.type, "cpu")
118+
new_dict = copy_data_to_device(original_dict, cuda_0)
119+
for key in new_dict.keys():
120+
self.assertEqual(new_dict[key].device.type, "cuda")
121+
107122
@skip_if_not_gpu
108123
def test_copy_data_to_device_named_tuple(self) -> None:
109124
cuda_0 = torch.device("cuda:0")

torchtnt/utils/device.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,16 @@ def copy_data_to_device(
8686
for k, v in data.items()
8787
},
8888
)
89-
elif issubclass(data_type, dict):
89+
elif (
90+
hasattr(data, "items")
91+
and hasattr(data, "__getitem__")
92+
and hasattr(data, "__iter__")
93+
):
94+
# pyre-ignore: Too many arguments [19]: Call `object.__init__` expects 0 positional arguments, 1
9095
return data_type(
9196
{
9297
k: copy_data_to_device(v, device, *args, **kwargs)
98+
# pyre-ignore: Undefined attribute [16]: `Variable[T]` has no attribute `items`.
9399
for k, v in data.items()
94100
}
95101
)

0 commit comments

Comments
 (0)