Skip to content

Commit c8a8e76

Browse files
JKSenthilfacebook-github-bot
authored andcommitted
optimize extracting tensors in copy_data_to_device (#955)
Summary: Pull Request resolved: #955 Reviewed By: galrotem Differential Revision: D67719962 fbshipit-source-id: b1c9d9e8bf722734c4e7177539671462ecd4fb96
1 parent ed30bb6 commit c8a8e76

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

torchtnt/utils/device.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,6 @@ def get_device_from_env() -> torch.device:
5252
TSelf = TypeVar("TSelf")
5353

5454

55-
@runtime_checkable
56-
class _CopyableData(Protocol):
57-
def to(self: TSelf, device: torch.device, *args: Any, **kwargs: Any) -> TSelf:
58-
"""Copy data to the specified device"""
59-
...
60-
61-
6255
def _is_named_tuple(x: T) -> bool:
6356
return isinstance(x, tuple) and hasattr(x, "_asdict") and hasattr(x, "_fields")
6457

@@ -76,31 +69,33 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any
7669
The data on the correct device
7770
"""
7871

79-
# Redundant isinstance(data, tuple) check is required here to make pyre happy
80-
if _is_named_tuple(data) and isinstance(data, tuple):
81-
return type(data)(
82-
**copy_data_to_device(data._asdict(), device, *args, **kwargs)
83-
)
84-
elif isinstance(data, (list, tuple)):
85-
return type(data)(copy_data_to_device(e, device, *args, **kwargs) for e in data)
86-
elif isinstance(data, defaultdict):
87-
return type(data)(
72+
data_type = type(data)
73+
if issubclass(data_type, defaultdict):
74+
return data_type(
8875
data.default_factory,
8976
{
9077
k: copy_data_to_device(v, device, *args, **kwargs)
9178
for k, v in data.items()
9279
},
9380
)
94-
elif isinstance(data, Mapping):
95-
return type(data)(
81+
elif issubclass(data_type, dict):
82+
return data_type(
9683
{
9784
k: copy_data_to_device(v, device, *args, **kwargs)
9885
for k, v in data.items()
9986
}
10087
)
101-
elif is_dataclass(data) and not isinstance(data, type):
102-
# pyre-fixme[45]: Cannot instantiate protocol `DataclassInstance`.
103-
new_data_class = type(data)(
88+
elif issubclass(data_type, list):
89+
return data_type(copy_data_to_device(e, device, *args, **kwargs) for e in data)
90+
elif issubclass(data_type, tuple):
91+
if hasattr(data, "_asdict") and hasattr(data, "_fields"):
92+
return data_type(
93+
**copy_data_to_device(data._asdict(), device, *args, **kwargs)
94+
)
95+
return data_type(copy_data_to_device(e, device, *args, **kwargs) for e in data)
96+
# checking for __dataclass_fields__ is official way to check if data is a dataclass
97+
elif hasattr(data, "__dataclass_fields__"):
98+
new_data_class = data_type(
10499
**{
105100
field.name: copy_data_to_device(
106101
getattr(data, field.name), device, *args, **kwargs
@@ -118,13 +113,11 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any
118113
getattr(data, field.name), device, *args, **kwargs
119114
),
120115
)
121-
# pyre-fixme[7]: Expected `T` but got `DataclassInstance`.
122116
return new_data_class
123-
elif isinstance(data, _CopyableData):
124-
# pyre-fixme[7]: Expected `T` but got `_CopyableData`.
117+
elif hasattr(data, "to"):
118+
# pyre-ignore Undefined attribute [16]: `Variable[T]` has no attribute `to`
125119
return data.to(device, *args, **kwargs)
126-
# pyre-fixme[7]: Expected `T` but got `Union[Type[DataclassInstance],
127-
# DataclassInstance]`.
120+
128121
return data
129122

130123

0 commit comments

Comments
 (0)