|
4 | 4 | import pickle
|
5 | 5 | from collections.abc import Sequence
|
6 | 6 | from inspect import isclass
|
7 |
| -from itertools import chain |
8 | 7 | from types import FunctionType
|
9 | 8 | from typing import Any, Optional, Union
|
10 | 9 |
|
@@ -84,22 +83,24 @@ def enc_hook(self, obj: Any) -> Any:
|
84 | 83 | if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
|
85 | 84 | return self._encode_ndarray(obj)
|
86 | 85 |
|
87 |
| - if isinstance(obj, MultiModalKwargs): |
88 |
| - mm: MultiModalKwargs = obj |
89 |
| - if not mm.modalities: |
90 |
| - # just return the main dict if there are no modalities. |
91 |
| - return dict(mm) |
92 |
| - |
93 |
| - # ignore the main dict, it will be re-indexed. |
94 |
| - # Encode a list of MultiModalKwargsItems as plain dicts |
95 |
| - # + special handling for .field. |
96 |
| - # Any tensors *not* indexed by modality will be ignored. |
97 |
| - return [{ |
98 |
| - "modality": elem.modality, |
99 |
| - "key": elem.key, |
100 |
| - "data": self._encode_nested_tensors(elem.data), |
101 |
| - "field": self._encode_field(elem.field), |
102 |
| - } for item in mm._items_by_modality.values() for elem in item] |
| 86 | + if isinstance(obj, MultiModalKwargs): |
| 87 | + mm: MultiModalKwargs = obj |
| 88 | + if not mm.modalities: |
| 89 | + # just return the main dict if there are no modalities. |
| 90 | + return dict(mm) |
| 91 | + |
| 92 | + # ignore the main dict, it will be re-indexed. |
| 93 | + # Encode a list of MultiModalKwargsItems as plain dicts |
| 94 | + # + special handling for .field. |
| 95 | + # Any tensors *not* indexed by modality will be ignored. |
| 96 | + return [[{ |
| 97 | + "modality": elem.modality, |
| 98 | + "key": elem.key, |
| 99 | + "data": self._encode_nested_tensors(elem.data), |
| 100 | + "field": self._encode_field(elem.field), |
| 101 | + } for elem in item.values()] |
| 102 | + for itemlist in mm._items_by_modality.values() |
| 103 | + for item in itemlist] |
103 | 104 |
|
104 | 105 | if isinstance(obj, FunctionType):
|
105 | 106 | # `pickle` is generally faster than cloudpickle, but can have
|
@@ -199,7 +200,7 @@ def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
199 | 200 |
|
200 | 201 | def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
|
201 | 202 | all = []
|
202 |
| - for item in chain.from_iterable(obj): |
| 203 | + for item in obj: |
203 | 204 | elems = []
|
204 | 205 | for v in item:
|
205 | 206 | v["data"] = self._decode_nested_tensors(v["data"])
|
|
0 commit comments