
Commit 4d6f9d4

[Refactor Followup] Fix data package type and quality errors (#430)
## Summary

This PR fixes type and quality errors in the data package, bringing the total type error count down to 124. Some assumptions had to be made, so please watch for accidental changes in logic during review.

## Testing

Run benchmarks with various data input types to ensure they all work.

---

- [x] "I certify that all code in this PR is my own, except as noted below."

## Use of AI

- [x] Includes AI-assisted code completion
- [ ] Includes code generated by an AI application
- [ ] Includes AI-generated tests (NOTE: AI written tests should have a docstring that includes `## WRITTEN BY AI ##`)
2 parents 2cbb7c9 + a5f60b0 commit 4d6f9d4

14 files changed (+185, −159 lines)

Lines changed: 79 additions & 44 deletions

@@ -1,10 +1,9 @@
 from __future__ import annotations

-import contextlib
 from collections.abc import Callable
 from typing import Any, Protocol, Union, runtime_checkable

-from datasets import Dataset, IterableDataset
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from transformers import PreTrainedTokenizerBase

 from guidellm.data.utils import resolve_dataset_split
@@ -29,7 +28,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]: ...
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...


 class DatasetDeserializerFactory(
@@ -47,51 +46,16 @@ def deserialize(
         remove_columns: list[str] | None = None,
         **data_kwargs: dict[str, Any],
     ) -> Dataset | IterableDataset:
-        dataset = None
+        dataset: Dataset

         if type_ is None:
-            errors = []
-            # Note: There is no priority order for the deserializers, so all deserializers
-            # must be mutually exclusive to ensure deterministic behavior.
-            for name, deserializer in cls.registry.items():
-                deserializer_fn: DatasetDeserializer = (
-                    deserializer() if isinstance(deserializer, type) else deserializer
-                )
-
-                try:
-                    with contextlib.suppress(DataNotSupportedError):
-                        dataset = deserializer_fn(
-                            data=data,
-                            processor_factory=processor_factory,
-                            random_seed=random_seed,
-                            **data_kwargs,
-                        )
-                except Exception as e:
-                    errors.append(e)
-
-                if dataset is not None:
-                    break  # Found one that works. Continuing could overwrite it.
-
-            if dataset is None and len(errors) > 0:
-                raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
-                                            f"attempting to deserialize data {data}: {errors}")
-
-        elif deserializer := cls.get_registered_object(type_) is not None:
-            deserializer_fn: DatasetDeserializer = (
-                deserializer() if isinstance(deserializer, type) else deserializer
+            dataset = cls._deserialize_with_registered_deserializers(
+                data, processor_factory, random_seed, **data_kwargs
             )

-            dataset = deserializer_fn(
-                data=data,
-                processor_factory=processor_factory,
-                random_seed=random_seed,
-                **data_kwargs,
-            )
-
-            if dataset is None:
-                raise DataNotSupportedError(
-                    f"No suitable deserializer found for data {data} "
-                    f"with kwargs {data_kwargs} and deserializer type {type_}."
+        else:
+            dataset = cls._deserialize_with_specified_deserializer(
+                data, type_, processor_factory, random_seed, **data_kwargs
             )

         if resolve_split:
@@ -107,3 +71,74 @@ def deserialize(
             dataset = dataset.remove_columns(remove_columns)

         return dataset
+
+    @classmethod
+    def _deserialize_with_registered_deserializers(
+        cls,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        if cls.registry is None:
+            raise RuntimeError("registry is None; cannot deserialize dataset")
+        dataset: Dataset | None = None
+
+        errors: dict[str, Exception] = {}
+        # Note: There is no priority order for the deserializers, so all deserializers
+        # must be mutually exclusive to ensure deterministic behavior.
+        for _name, deserializer in cls.registry.items():
+            deserializer_fn: DatasetDeserializer = (
+                deserializer() if isinstance(deserializer, type) else deserializer
+            )
+
+            try:
+                dataset = deserializer_fn(
+                    data=data,
+                    processor_factory=processor_factory,
+                    random_seed=random_seed,
+                    **data_kwargs,
+                )
+            except Exception as e:  # noqa: BLE001 # The exceptions are saved.
+                errors[_name] = e
+
+            if dataset is not None:
+                return dataset  # Success
+
+        if len(errors) > 0:
+            err_msgs = ""
+            def sort_key(item):
+                return (isinstance(item[1], DataNotSupportedError), item[0])
+            for key, err in sorted(errors.items(), key=sort_key):
+                err_msgs += f"\n  - Deserializer '{key}': ({type(err).__name__}) {err}"
+            raise ValueError(
+                "Data deserialization failed, likely because the input doesn't "
+                f"match any of the input formats. See the {len(errors)} error(s) that "
+                f"occurred while attempting to deserialize the data {data}:{err_msgs}"
+            )
+        return dataset
+
+    @classmethod
+    def _deserialize_with_specified_deserializer(
+        cls,
+        data: Any,
+        type_: str,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        deserializer_from_type = cls.get_registered_object(type_)
+        if deserializer_from_type is None:
+            raise ValueError(f"Deserializer type '{type_}' is not registered.")
+        if isinstance(deserializer_from_type, type):
+            deserializer_fn = deserializer_from_type()
+        else:
+            deserializer_fn = deserializer_from_type
+
+        return deserializer_fn(
+            data=data,
+            processor_factory=processor_factory,
+            random_seed=random_seed,
+            **data_kwargs,
+        )
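
The refactor above replaces the inline registry loop with two helpers: `_deserialize_with_registered_deserializers`, which tries every registered deserializer and aggregates the failures, and `_deserialize_with_specified_deserializer`, which looks up a single deserializer by `type_`. The following is a minimal, self-contained sketch of that dispatch-and-aggregate pattern; the registry, function names, and error strings here are illustrative stand-ins, not guidellm's actual API.

```python
from __future__ import annotations

from typing import Any, Callable


class DataNotSupportedError(ValueError):
    """Raised by a deserializer that does not recognize the input format."""


def _from_dict(data: Any) -> dict[str, list]:
    # Accepts a mapping of column -> values, e.g. {"prompt": ["a", "b"]}.
    if not isinstance(data, dict):
        raise DataNotSupportedError(f"expected a dict of lists, got {type(data).__name__}")
    return {key: list(values) for key, values in data.items()}


def _from_list_of_dicts(data: Any) -> dict[str, list]:
    # Accepts a list of row dicts, e.g. [{"prompt": "a"}, {"prompt": "b"}].
    if not (isinstance(data, list) and data and all(isinstance(item, dict) for item in data)):
        raise DataNotSupportedError(f"expected a list of dicts, got {type(data).__name__}")
    return {key: [item[key] for item in data] for key in data[0]}


# Illustrative registry; guidellm's factory maps names to deserializer classes
# or instances rather than plain functions.
REGISTRY: dict[str, Callable[[Any], dict[str, list]]] = {
    "dict": _from_dict,
    "dict_list": _from_list_of_dicts,
}


def deserialize_with_registry(data: Any) -> dict[str, list]:
    """Try every registered deserializer; aggregate all failures if none accepts the data."""
    errors: dict[str, Exception] = {}
    for name, fn in REGISTRY.items():
        try:
            return fn(data)  # the first deserializer that accepts the data wins
        except Exception as err:  # noqa: BLE001 - errors are reported together below
            errors[name] = err
    details = "".join(
        f"\n  - '{name}': ({type(err).__name__}) {err}" for name, err in sorted(errors.items())
    )
    raise ValueError(f"Data deserialization failed:{details}")


if __name__ == "__main__":
    print(deserialize_with_registry([{"prompt": "a"}, {"prompt": "b"}]))
    # -> {'prompt': ['a', 'b']}
```

Because the registry has no priority order, the deserializers themselves must stay mutually exclusive; otherwise whichever one happens to accept the data first wins.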

src/guidellm/data/deserializers/file.py

Lines changed: 14 additions & 14 deletions
@@ -34,11 +34,11 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".txt", ".text"}
@@ -62,10 +62,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".csv"
@@ -86,10 +86,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".json", ".jsonl"}
@@ -110,10 +110,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".parquet"
@@ -134,10 +134,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".arrow"
@@ -158,10 +158,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".hdf5", ".h5"}
@@ -185,7 +185,7 @@ def __call__(
     ) -> dict[str, list]:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".db"
@@ -209,7 +209,7 @@ def __call__(
     ) -> dict[str, list]:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".tar"
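
Every file deserializer above shares the same guard: the input must be a `str` or `Path` to an existing file with an expected suffix, otherwise it raises `DataNotSupportedError` so the factory can move on to the next deserializer. Below is a hedged sketch of that guard as a standalone helper; the helper name and error message are invented for illustration. Note that `isinstance(data, str | Path)` relies on PEP 604 unions being accepted by `isinstance`, which requires Python 3.10 or newer.

```python
from pathlib import Path
from typing import Any


class DataNotSupportedError(ValueError):
    """Raised when the input does not look like a supported file."""


def require_file(data: Any, suffixes: set[str]) -> Path:
    """Hypothetical helper mirroring the guard shared by the file deserializers."""
    if (
        not isinstance(data, str | Path)  # PEP 604 union in isinstance: Python 3.10+
        or not (path := Path(data)).exists()
        or not path.is_file()
        or path.suffix.lower() not in suffixes
    ):
        raise DataNotSupportedError(
            f"expected a path to an existing file with suffix in {suffixes}, got {data!r}"
        )
    return path


# Example: the text deserializer accepts .txt/.text, the CSV one only .csv.
# require_file("samples.txt", {".txt", ".text"})
```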

src/guidellm/data/deserializers/huggingface.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
         _ = (processor_factory, random_seed)

         if isinstance(
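
The broader return annotation matches what `datasets.load_dataset` can actually hand back: a `DatasetDict` (or `IterableDatasetDict` when streaming) if no split is requested, and a plain `Dataset` when one is. A short illustration, assuming network access and using the public `ag_news` dataset purely as an example:

```python
# Illustration only (not guidellm code): why the Hugging Face deserializer's return
# type now includes the *Dict variants.
from datasets import Dataset, DatasetDict, load_dataset

splits = load_dataset("ag_news")                 # no split requested -> DatasetDict
assert isinstance(splits, DatasetDict)

train = load_dataset("ag_news", split="train")   # explicit split -> a single Dataset
assert isinstance(train, Dataset)

# With streaming=True the same calls return IterableDatasetDict / IterableDataset.
```

This is presumably why the factory's `deserialize` exposes the `resolve_split` flag seen in the first file's hunks: a `*Dict` result still needs a split selected before it can be used as a single dataset.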

src/guidellm/data/deserializers/memory.py

Lines changed: 20 additions & 18 deletions
@@ -33,7 +33,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         if (
@@ -67,7 +67,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         if (
@@ -81,9 +81,9 @@ def __call__(
                 f"expected list of dicts, got {data}"
             )

-        data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
-        first_keys = set(data[0].keys())
-        for index, item in enumerate(data):
+        typed_data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
+        first_keys = set(typed_data[0].keys())
+        for index, item in enumerate(typed_data):
             if set(item.keys()) != first_keys:
                 raise DataNotSupportedError(
                     f"All dictionaries must have the same keys. "
@@ -92,8 +92,8 @@ def __call__(
             )

         # Convert list of dicts to dict of lists
-        result_dict = {key: [] for key in first_keys}
-        for item in data:
+        result_dict: dict = {key: [] for key in first_keys}
+        for item in typed_data:
             for key, value in item.items():
                 result_dict[key].append(value)

@@ -108,7 +108,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors

         primitive_types = (str, int, float, bool, type(None))
@@ -135,7 +135,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         if (
             isinstance(data, str)
             and (json_str := data.strip())
@@ -145,16 +145,18 @@ def __call__(
             )
         ):
             with contextlib.suppress(Exception):
-                parsed = json.loads(data)
+                parsed_data = json.loads(data)

-            for deserializer in [
-                InMemoryDictDatasetDeserializer,
-                InMemoryDictListDatasetDeserializer,
-                InMemoryItemListDatasetDeserializer,
-            ]:
+            deserializers = [
+                InMemoryDictDatasetDeserializer(),
+                InMemoryDictListDatasetDeserializer(),
+                InMemoryItemListDatasetDeserializer(),
+            ]
+
+            for deserializer in deserializers:
                 with contextlib.suppress(DataNotSupportedError):
-                    return deserializer()(
-                        parsed, data_kwargs, processor_factory, random_seed
+                    return deserializer(
+                        parsed_data, processor_factory, random_seed, **data_kwargs
                     )

         raise DataNotSupportedError(
@@ -171,7 +173,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         if (
             isinstance(data, str)
             and (csv_str := data.strip())
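
The in-memory deserializers normalize their inputs into a mapping of column name to list of values, which is the shape that `datasets.Dataset.from_dict` accepts, and the JSON-string variant now parses the text once and delegates to instances of those same deserializers with consistent keyword arguments. A minimal sketch of the list-of-dicts path, with column names invented purely for the example:

```python
from datasets import Dataset

# Invented column names, purely for illustration.
rows = [
    {"prompt": "Hello", "output_tokens": 16},
    {"prompt": "World", "output_tokens": 32},
]

# Mirror of the "same keys" check in InMemoryDictListDatasetDeserializer.
first_keys = set(rows[0])
if any(set(row) != first_keys for row in rows):
    raise ValueError("All dictionaries must have the same keys.")

# Convert the list of dicts into a dict of lists, then build a Dataset.
columns: dict[str, list] = {key: [row[key] for row in rows] for key in rows[0]}
dataset = Dataset.from_dict(columns)
print(dataset)  # Dataset({features: ['prompt', 'output_tokens'], num_rows: 2})
```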
