File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -91,8 +91,12 @@ def enc_hook(self, obj: Any) -> Any:
91
91
ret = []
92
92
for elem in obj .values ():
93
93
# Encode as plain dictionary + special handling for .field
94
- ret .append (
95
- asdict (elem ) | {"field" : self ._encode_field (elem .field )})
94
+ ret .append ({
95
+ "modality" : elem .modality ,
96
+ "key" : elem .key ,
97
+ "data" : self ._encode_nested_tensors (elem .data ),
98
+ "field" : self ._encode_field (elem .field ),
99
+ })
96
100
return ret
97
101
98
102
if isinstance (obj , FunctionType ):
@@ -126,6 +130,11 @@ def _encode_ndarray(
126
130
# backing buffers that we've stashed in `aux_buffers`.
127
131
return obj .dtype .str , obj .shape , data
128
132
133
+ def _encode_nested_tensors (self , obj : Any ) -> NestedTensors :
134
+ if isinstance (obj , torch .Tensor ):
135
+ return self ._encode_ndarray (obj .numpy ())
136
+ return [self ._encode_nested_tensors (x ) for x in obj ]
137
+
129
138
def _encode_field (self , field : BaseMultiModalField ):
130
139
# Encode the field as a dictionary + special handling for .field
131
140
d = asdict (field )
You can’t perform that action at this time.
0 commit comments