Skip to content

Commit e818ae5

Browse files
authored
Merge branch 'main' into main
2 parents 8989f9c + 4d6f9d4 commit e818ae5

File tree

14 files changed

+185
-159
lines changed

14 files changed

+185
-159
lines changed
Lines changed: 79 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from __future__ import annotations
22

3-
import contextlib
43
from collections.abc import Callable
54
from typing import Any, Protocol, Union, runtime_checkable
65

7-
from datasets import Dataset, IterableDataset
6+
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
87
from transformers import PreTrainedTokenizerBase
98

109
from guidellm.data.utils import resolve_dataset_split
@@ -29,7 +28,7 @@ def __call__(
2928
processor_factory: Callable[[], PreTrainedTokenizerBase],
3029
random_seed: int,
3130
**data_kwargs: dict[str, Any],
32-
) -> dict[str, list]: ...
31+
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...
3332

3433

3534
class DatasetDeserializerFactory(
@@ -47,51 +46,16 @@ def deserialize(
4746
remove_columns: list[str] | None = None,
4847
**data_kwargs: dict[str, Any],
4948
) -> Dataset | IterableDataset:
50-
dataset = None
49+
dataset: Dataset
5150

5251
if type_ is None:
53-
errors = []
54-
# Note: There is no priority order for the deserializers, so all deserializers
55-
# must be mutually exclusive to ensure deterministic behavior.
56-
for name, deserializer in cls.registry.items():
57-
deserializer_fn: DatasetDeserializer = (
58-
deserializer() if isinstance(deserializer, type) else deserializer
59-
)
60-
61-
try:
62-
with contextlib.suppress(DataNotSupportedError):
63-
dataset = deserializer_fn(
64-
data=data,
65-
processor_factory=processor_factory,
66-
random_seed=random_seed,
67-
**data_kwargs,
68-
)
69-
except Exception as e:
70-
errors.append(e)
71-
72-
if dataset is not None:
73-
break # Found one that works. Continuing could overwrite it.
74-
75-
if dataset is None and len(errors) > 0:
76-
raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
77-
f"attempting to deserialize data {data}: {errors}")
78-
79-
elif deserializer := cls.get_registered_object(type_) is not None:
80-
deserializer_fn: DatasetDeserializer = (
81-
deserializer() if isinstance(deserializer, type) else deserializer
52+
dataset = cls._deserialize_with_registered_deserializers(
53+
data, processor_factory, random_seed, **data_kwargs
8254
)
8355

84-
dataset = deserializer_fn(
85-
data=data,
86-
processor_factory=processor_factory,
87-
random_seed=random_seed,
88-
**data_kwargs,
89-
)
90-
91-
if dataset is None:
92-
raise DataNotSupportedError(
93-
f"No suitable deserializer found for data {data} "
94-
f"with kwargs {data_kwargs} and deserializer type {type_}."
56+
else:
57+
dataset = cls._deserialize_with_specified_deserializer(
58+
data, type_, processor_factory, random_seed, **data_kwargs
9559
)
9660

9761
if resolve_split:
@@ -107,3 +71,74 @@ def deserialize(
10771
dataset = dataset.remove_columns(remove_columns)
10872

10973
return dataset
74+
75+
@classmethod
76+
def _deserialize_with_registered_deserializers(
77+
cls,
78+
data: Any,
79+
processor_factory: Callable[[], PreTrainedTokenizerBase],
80+
random_seed: int = 42,
81+
**data_kwargs: dict[str, Any],
82+
) -> Dataset:
83+
if cls.registry is None:
84+
raise RuntimeError("registry is None; cannot deserialize dataset")
85+
dataset: Dataset | None = None
86+
87+
errors: dict[str, Exception] = {}
88+
# Note: There is no priority order for the deserializers, so all deserializers
89+
# must be mutually exclusive to ensure deterministic behavior.
90+
for _name, deserializer in cls.registry.items():
91+
deserializer_fn: DatasetDeserializer = (
92+
deserializer() if isinstance(deserializer, type) else deserializer
93+
)
94+
95+
try:
96+
dataset = deserializer_fn(
97+
data=data,
98+
processor_factory=processor_factory,
99+
random_seed=random_seed,
100+
**data_kwargs,
101+
)
102+
except Exception as e: # noqa: BLE001 # The exceptions are saved.
103+
errors[_name] = e
104+
105+
if dataset is not None:
106+
return dataset # Success
107+
108+
if len(errors) > 0:
109+
err_msgs = ""
110+
def sort_key(item):
111+
return (isinstance(item[1], DataNotSupportedError), item[0])
112+
for key, err in sorted(errors.items(), key=sort_key):
113+
err_msgs += f"\n - Deserializer '{key}': ({type(err).__name__}) {err}"
114+
raise ValueError(
115+
"Data deserialization failed, likely because the input doesn't "
116+
f"match any of the input formats. See the {len(errors)} error(s) that "
117+
f"occurred while attempting to deserialize the data {data}:{err_msgs}"
118+
)
119+
return dataset
120+
121+
@classmethod
122+
def _deserialize_with_specified_deserializer(
123+
cls,
124+
data: Any,
125+
type_: str,
126+
processor_factory: Callable[[], PreTrainedTokenizerBase],
127+
random_seed: int = 42,
128+
**data_kwargs: dict[str, Any],
129+
) -> Dataset:
130+
deserializer_from_type = cls.get_registered_object(type_)
131+
if deserializer_from_type is None:
132+
raise ValueError(f"Deserializer type '{type_}' is not registered.")
133+
if isinstance(deserializer_from_type, type):
134+
deserializer_fn = deserializer_from_type()
135+
else:
136+
deserializer_fn = deserializer_from_type
137+
138+
return deserializer_fn(
139+
data=data,
140+
processor_factory=processor_factory,
141+
random_seed=random_seed,
142+
**data_kwargs,
143+
)
144+

src/guidellm/data/deserializers/file.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,11 @@ def __call__(
3434
processor_factory: Callable[[], PreTrainedTokenizerBase],
3535
random_seed: int,
3636
**data_kwargs: dict[str, Any],
37-
) -> dict[str, list]:
37+
) -> Dataset:
3838
_ = (processor_factory, random_seed) # Ignore unused args format errors
3939

4040
if (
41-
not isinstance(data, (str, Path))
41+
not isinstance(data, str | Path)
4242
or not (path := Path(data)).exists()
4343
or not path.is_file()
4444
or path.suffix.lower() not in {".txt", ".text"}
@@ -62,10 +62,10 @@ def __call__(
6262
processor_factory: Callable[[], PreTrainedTokenizerBase],
6363
random_seed: int,
6464
**data_kwargs: dict[str, Any],
65-
) -> dict[str, list]:
65+
) -> Dataset:
6666
_ = (processor_factory, random_seed)
6767
if (
68-
not isinstance(data, (str, Path))
68+
not isinstance(data, str | Path)
6969
or not (path := Path(data)).exists()
7070
or not path.is_file()
7171
or path.suffix.lower() != ".csv"
@@ -86,10 +86,10 @@ def __call__(
8686
processor_factory: Callable[[], PreTrainedTokenizerBase],
8787
random_seed: int,
8888
**data_kwargs: dict[str, Any],
89-
) -> dict[str, list]:
89+
) -> Dataset:
9090
_ = (processor_factory, random_seed)
9191
if (
92-
not isinstance(data, (str, Path))
92+
not isinstance(data, str | Path)
9393
or not (path := Path(data)).exists()
9494
or not path.is_file()
9595
or path.suffix.lower() not in {".json", ".jsonl"}
@@ -110,10 +110,10 @@ def __call__(
110110
processor_factory: Callable[[], PreTrainedTokenizerBase],
111111
random_seed: int,
112112
**data_kwargs: dict[str, Any],
113-
) -> dict[str, list]:
113+
) -> Dataset:
114114
_ = (processor_factory, random_seed)
115115
if (
116-
not isinstance(data, (str, Path))
116+
not isinstance(data, str | Path)
117117
or not (path := Path(data)).exists()
118118
or not path.is_file()
119119
or path.suffix.lower() != ".parquet"
@@ -134,10 +134,10 @@ def __call__(
134134
processor_factory: Callable[[], PreTrainedTokenizerBase],
135135
random_seed: int,
136136
**data_kwargs: dict[str, Any],
137-
) -> dict[str, list]:
137+
) -> Dataset:
138138
_ = (processor_factory, random_seed)
139139
if (
140-
not isinstance(data, (str, Path))
140+
not isinstance(data, str | Path)
141141
or not (path := Path(data)).exists()
142142
or not path.is_file()
143143
or path.suffix.lower() != ".arrow"
@@ -158,10 +158,10 @@ def __call__(
158158
processor_factory: Callable[[], PreTrainedTokenizerBase],
159159
random_seed: int,
160160
**data_kwargs: dict[str, Any],
161-
) -> dict[str, list]:
161+
) -> Dataset:
162162
_ = (processor_factory, random_seed)
163163
if (
164-
not isinstance(data, (str, Path))
164+
not isinstance(data, str | Path)
165165
or not (path := Path(data)).exists()
166166
or not path.is_file()
167167
or path.suffix.lower() not in {".hdf5", ".h5"}
@@ -185,7 +185,7 @@ def __call__(
185185
) -> dict[str, list]:
186186
_ = (processor_factory, random_seed)
187187
if (
188-
not isinstance(data, (str, Path))
188+
not isinstance(data, str | Path)
189189
or not (path := Path(data)).exists()
190190
or not path.is_file()
191191
or path.suffix.lower() != ".db"
@@ -209,7 +209,7 @@ def __call__(
209209
) -> dict[str, list]:
210210
_ = (processor_factory, random_seed)
211211
if (
212-
not isinstance(data, (str, Path))
212+
not isinstance(data, str | Path)
213213
or not (path := Path(data)).exists()
214214
or not path.is_file()
215215
or path.suffix.lower() != ".tar"

src/guidellm/data/deserializers/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __call__(
3636
processor_factory: Callable[[], PreTrainedTokenizerBase],
3737
random_seed: int,
3838
**data_kwargs: dict[str, Any],
39-
) -> dict[str, list]:
39+
) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
4040
_ = (processor_factory, random_seed)
4141

4242
if isinstance(

src/guidellm/data/deserializers/memory.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def __call__(
3333
processor_factory: Callable[[], PreTrainedTokenizerBase],
3434
random_seed: int,
3535
**data_kwargs: dict[str, Any],
36-
) -> dict[str, list]:
36+
) -> Dataset:
3737
_ = (processor_factory, random_seed) # Ignore unused args format errors
3838

3939
if (
@@ -67,7 +67,7 @@ def __call__(
6767
processor_factory: Callable[[], PreTrainedTokenizerBase],
6868
random_seed: int,
6969
**data_kwargs: dict[str, Any],
70-
) -> dict[str, list]:
70+
) -> Dataset:
7171
_ = (processor_factory, random_seed) # Ignore unused args format errors
7272

7373
if (
@@ -81,9 +81,9 @@ def __call__(
8181
f"expected list of dicts, got {data}"
8282
)
8383

84-
data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
85-
first_keys = set(data[0].keys())
86-
for index, item in enumerate(data):
84+
typed_data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
85+
first_keys = set(typed_data[0].keys())
86+
for index, item in enumerate(typed_data):
8787
if set(item.keys()) != first_keys:
8888
raise DataNotSupportedError(
8989
f"All dictionaries must have the same keys. "
@@ -92,8 +92,8 @@ def __call__(
9292
)
9393

9494
# Convert list of dicts to dict of lists
95-
result_dict = {key: [] for key in first_keys}
96-
for item in data:
95+
result_dict: dict = {key: [] for key in first_keys}
96+
for item in typed_data:
9797
for key, value in item.items():
9898
result_dict[key].append(value)
9999

@@ -108,7 +108,7 @@ def __call__(
108108
processor_factory: Callable[[], PreTrainedTokenizerBase],
109109
random_seed: int,
110110
**data_kwargs: dict[str, Any],
111-
) -> dict[str, list]:
111+
) -> Dataset:
112112
_ = (processor_factory, random_seed) # Ignore unused args format errors
113113

114114
primitive_types = (str, int, float, bool, type(None))
@@ -135,7 +135,7 @@ def __call__(
135135
processor_factory: Callable[[], PreTrainedTokenizerBase],
136136
random_seed: int,
137137
**data_kwargs: dict[str, Any],
138-
) -> dict[str, list]:
138+
) -> Dataset:
139139
if (
140140
isinstance(data, str)
141141
and (json_str := data.strip())
@@ -145,16 +145,18 @@ def __call__(
145145
)
146146
):
147147
with contextlib.suppress(Exception):
148-
parsed = json.loads(data)
148+
parsed_data = json.loads(data)
149149

150-
for deserializer in [
151-
InMemoryDictDatasetDeserializer,
152-
InMemoryDictListDatasetDeserializer,
153-
InMemoryItemListDatasetDeserializer,
154-
]:
150+
deserializers = [
151+
InMemoryDictDatasetDeserializer(),
152+
InMemoryDictListDatasetDeserializer(),
153+
InMemoryItemListDatasetDeserializer(),
154+
]
155+
156+
for deserializer in deserializers:
155157
with contextlib.suppress(DataNotSupportedError):
156-
return deserializer()(
157-
parsed, data_kwargs, processor_factory, random_seed
158+
return deserializer(
159+
parsed_data, processor_factory, random_seed, **data_kwargs
158160
)
159161

160162
raise DataNotSupportedError(
@@ -171,7 +173,7 @@ def __call__(
171173
processor_factory: Callable[[], PreTrainedTokenizerBase],
172174
random_seed: int,
173175
**data_kwargs: dict[str, Any],
174-
) -> dict[str, list]:
176+
) -> Dataset:
175177
if (
176178
isinstance(data, str)
177179
and (csv_str := data.strip())

0 commit comments

Comments
 (0)