@@ -75,9 +75,9 @@ def enc_hook(self, obj: Any) -> Any:
75
75
# ignore the main dict, it will be re-indexed.
76
76
# pass a list of MultiModalKwargsItem, then see below
77
77
# Any tensors *not* indexed by modality will be ignored.
78
- return [ mm .get_items ( m ) for m in mm . modalities ]
78
+ return mm ._items_by_modality . values ()
79
79
# just return the main dict if there are no modalities
80
- return { k : v for k , v in obj . items ()}
80
+ return dict ( mm )
81
81
82
82
if isinstance (obj , MultiModalKwargsItem ):
83
83
# Encode as plain dictionary + special handling for '.field'
@@ -146,12 +146,12 @@ def dec_hook(self, t: type, obj: Any) -> Any:
146
146
if isclass (t ):
147
147
if issubclass (t , np .ndarray ):
148
148
return self ._decode_ndarray (obj )
149
- if issubclass (t , MultiModalKwargs ) and isinstance (obj , dict ):
149
+ if issubclass (t , MultiModalKwargs )
150
+ if isinstance (obj , list ):
151
+ return MultiModalKwargs .from_items (self ._decode_items (obj ))
150
152
return MultiModalKwargs (
151
- {k : self ._decode_nested (obj [k ])
152
- for k in obj })
153
- if issubclass (t , MultiModalKwargs ) and isinstance (obj , list ):
154
- return MultiModalKwargs .from_items (self ._decode_items (obj ))
153
+ {k : self ._decode_nested (v )
154
+ for k in obj .items ()})
155
155
if issubclass (t , torch .Tensor ):
156
156
return torch .from_numpy (self ._decode_ndarray (obj ))
157
157
return obj
@@ -172,12 +172,12 @@ def _decode_items(self, obj: list) -> list[MultiModalKwargsItem]:
172
172
all .append (MultiModalKwargsItem .from_elems (elems ))
173
173
return all
174
174
175
- def _decode_nested (self , obj : Any ) -> NestedTensors :
176
- if isinstance (obj , list ) and isinstance (obj [0 ], str ):
175
+ def _decode_nested_tensors (self , obj : Any ) -> NestedTensors :
176
+ if not isinstance (obj , list ):
177
+ raise TypeError (f"Unexpected NestedTensors contents: { type (obj )} " )
178
+ if obj and isinstance (obj [0 ], str ):
177
179
return torch .from_numpy (self ._decode_ndarray (obj ))
178
- if isinstance (obj , list ):
179
- return [self ._decode_nested (x ) for x in obj ]
180
- raise TypeError (f"Unexpected NestedArray contents: { obj } " )
180
+ return [self ._decode_nested_tensors (x ) for x in obj ]
181
181
182
182
def ext_hook (self , code : int , data : memoryview ) -> Any :
183
183
if code == CUSTOM_TYPE_PICKLE :
0 commit comments