@@ -52,13 +52,6 @@ def get_device_from_env() -> torch.device:
52
52
TSelf = TypeVar ("TSelf" )
53
53
54
54
55
- @runtime_checkable
56
- class _CopyableData (Protocol ):
57
- def to (self : TSelf , device : torch .device , * args : Any , ** kwargs : Any ) -> TSelf :
58
- """Copy data to the specified device"""
59
- ...
60
-
61
-
62
55
def _is_named_tuple (x : T ) -> bool :
63
56
return isinstance (x , tuple ) and hasattr (x , "_asdict" ) and hasattr (x , "_fields" )
64
57
@@ -76,31 +69,33 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any
76
69
The data on the correct device
77
70
"""
78
71
79
- # Redundant isinstance(data, tuple) check is required here to make pyre happy
80
- if _is_named_tuple (data ) and isinstance (data , tuple ):
81
- return type (data )(
82
- ** copy_data_to_device (data ._asdict (), device , * args , ** kwargs )
83
- )
84
- elif isinstance (data , (list , tuple )):
85
- return type (data )(copy_data_to_device (e , device , * args , ** kwargs ) for e in data )
86
- elif isinstance (data , defaultdict ):
87
- return type (data )(
72
+ data_type = type (data )
73
+ if issubclass (data_type , defaultdict ):
74
+ return data_type (
88
75
data .default_factory ,
89
76
{
90
77
k : copy_data_to_device (v , device , * args , ** kwargs )
91
78
for k , v in data .items ()
92
79
},
93
80
)
94
- elif isinstance ( data , Mapping ):
95
- return type ( data ) (
81
+ elif issubclass ( data_type , dict ):
82
+ return data_type (
96
83
{
97
84
k : copy_data_to_device (v , device , * args , ** kwargs )
98
85
for k , v in data .items ()
99
86
}
100
87
)
101
- elif is_dataclass (data ) and not isinstance (data , type ):
102
- # pyre-fixme[45]: Cannot instantiate protocol `DataclassInstance`.
103
- new_data_class = type (data )(
88
+ elif issubclass (data_type , list ):
89
+ return data_type (copy_data_to_device (e , device , * args , ** kwargs ) for e in data )
90
+ elif issubclass (data_type , tuple ):
91
+ if hasattr (data , "_asdict" ) and hasattr (data , "_fields" ):
92
+ return data_type (
93
+ ** copy_data_to_device (data ._asdict (), device , * args , ** kwargs )
94
+ )
95
+ return data_type (copy_data_to_device (e , device , * args , ** kwargs ) for e in data )
96
+ # checking for __dataclass_fields__ is official way to check if data is a dataclass
97
+ elif hasattr (data , "__dataclass_fields__" ):
98
+ new_data_class = data_type (
104
99
** {
105
100
field .name : copy_data_to_device (
106
101
getattr (data , field .name ), device , * args , ** kwargs
@@ -118,13 +113,11 @@ def copy_data_to_device(data: T, device: torch.device, *args: Any, **kwargs: Any
118
113
getattr (data , field .name ), device , * args , ** kwargs
119
114
),
120
115
)
121
- # pyre-fixme[7]: Expected `T` but got `DataclassInstance`.
122
116
return new_data_class
123
- elif isinstance (data , _CopyableData ):
124
- # pyre-fixme[7 ]: Expected `T` but got `_CopyableData`.
117
+ elif hasattr (data , "to" ):
118
+ # pyre-ignore Undefined attribute [16 ]: `Variable[T]` has no attribute `to`
125
119
return data .to (device , * args , ** kwargs )
126
- # pyre-fixme[7]: Expected `T` but got `Union[Type[DataclassInstance],
127
- # DataclassInstance]`.
120
+
128
121
return data
129
122
130
123
0 commit comments