
Commit 5c06a8a

Fix data package type and quality errors
Signed-off-by: Jared O'Connell <[email protected]>
1 parent 2cbb7c9 commit 5c06a8a

13 files changed: 143 additions & 101 deletions

src/guidellm/data/deserializers/deserializer.py

Lines changed: 52 additions & 32 deletions
@@ -4,7 +4,7 @@
 from collections.abc import Callable
 from typing import Any, Protocol, Union, runtime_checkable
 
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
 from transformers import PreTrainedTokenizerBase
 
 from guidellm.data.utils import resolve_dataset_split
@@ -29,7 +29,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]: ...
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict: ...
 
 
 class DatasetDeserializerFactory(
@@ -47,40 +47,19 @@ def deserialize(
         remove_columns: list[str] | None = None,
         **data_kwargs: dict[str, Any],
     ) -> Dataset | IterableDataset:
-        dataset = None
+        dataset: Dataset | None = None
 
         if type_ is None:
-            errors = []
-            # Note: There is no priority order for the deserializers, so all deserializers
-            # must be mutually exclusive to ensure deterministic behavior.
-            for name, deserializer in cls.registry.items():
-                deserializer_fn: DatasetDeserializer = (
-                    deserializer() if isinstance(deserializer, type) else deserializer
-                )
-
-                try:
-                    with contextlib.suppress(DataNotSupportedError):
-                        dataset = deserializer_fn(
-                            data=data,
-                            processor_factory=processor_factory,
-                            random_seed=random_seed,
-                            **data_kwargs,
-                        )
-                except Exception as e:
-                    errors.append(e)
-
-                if dataset is not None:
-                    break  # Found one that works. Continuing could overwrite it.
-
-            if dataset is None and len(errors) > 0:
-                raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
-                                            f"attempting to deserialize data {data}: {errors}")
-
-        elif deserializer := cls.get_registered_object(type_) is not None:
-            deserializer_fn: DatasetDeserializer = (
-                deserializer() if isinstance(deserializer, type) else deserializer
+            dataset = cls._deserialize_with_registered_deserializers(
+                data, processor_factory, random_seed, **data_kwargs
             )
 
+        elif (deserializer_from_type := cls.get_registered_object(type_)) is not None:
+            if isinstance(deserializer_from_type, type):
+                deserializer_fn = deserializer_from_type()
+            else:
+                deserializer_fn = deserializer_from_type
+
             dataset = deserializer_fn(
                 data=data,
                 processor_factory=processor_factory,
@@ -107,3 +86,44 @@ def deserialize(
             dataset = dataset.remove_columns(remove_columns)
 
         return dataset
+
+    @classmethod
+    def _deserialize_with_registered_deserializers(
+        cls,
+        data: Any,
+        processor_factory: Callable[[], PreTrainedTokenizerBase],
+        random_seed: int = 42,
+        **data_kwargs: dict[str, Any],
+    ) -> Dataset:
+        if cls.registry is None:
+            raise RuntimeError("registry is None; cannot deserialize dataset")
+        dataset: Dataset | None = None
+
+        errors = []
+        # Note: There is no priority order for the deserializers, so all deserializers
+        # must be mutually exclusive to ensure deterministic behavior.
+        for _name, deserializer in cls.registry.items():
+            deserializer_fn: DatasetDeserializer = (
+                deserializer() if isinstance(deserializer, type) else deserializer
+            )
+
+            try:
+                with contextlib.suppress(DataNotSupportedError):
+                    dataset = deserializer_fn(
+                        data=data,
+                        processor_factory=processor_factory,
+                        random_seed=random_seed,
+                        **data_kwargs,
+                    )
+            except Exception as e:  # noqa: BLE001 # The exceptions are saved.
+                errors.append(e)
+
+            if dataset is not None:
+                break  # Found one that works. Continuing could overwrite it.
+
+        if dataset is None and len(errors) > 0:
+            raise DataNotSupportedError(
+                f"data deserialization failed; {len(errors)} errors occurred while "
+                f"attempting to deserialize data {data}: {errors}"
+            )
+        return dataset
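
Note on the deserializer.py fix: in the removed branch, `elif deserializer := cls.get_registered_object(type_) is not None:` binds `deserializer` to the result of the `is not None` comparison, because `:=` has lower precedence than `is`. The rewritten branch parenthesizes the assignment so the registered object itself is captured. A minimal sketch of the pitfall, using a hypothetical `lookup` stand-in rather than the real registry:

    # `lookup` is a hypothetical stand-in for cls.get_registered_object.
    def lookup(key: str) -> str | None:
        return {"csv": "CSVDeserializer"}.get(key)

    # Without parentheses, `:=` binds the whole comparison...
    if obj := lookup("csv") is not None:
        print(obj)  # True -- the boolean, not the deserializer

    # ...with parentheses, as in the fixed code, `obj` holds the object.
    if (obj := lookup("csv")) is not None:
        print(obj)  # CSVDeserializer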

src/guidellm/data/deserializers/file.py

Lines changed: 14 additions & 14 deletions
@@ -34,11 +34,11 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors
 
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".txt", ".text"}
@@ -62,10 +62,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".csv"
@@ -86,10 +86,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".json", ".jsonl"}
@@ -110,10 +110,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".parquet"
@@ -134,10 +134,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".arrow"
@@ -158,10 +158,10 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() not in {".hdf5", ".h5"}
@@ -185,7 +185,7 @@ def __call__(
     ) -> dict[str, list]:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".db"
@@ -209,7 +209,7 @@ def __call__(
     ) -> dict[str, list]:
         _ = (processor_factory, random_seed)
         if (
-            not isinstance(data, (str, Path))
+            not isinstance(data, str | Path)
             or not (path := Path(data)).exists()
             or not path.is_file()
             or path.suffix.lower() != ".tar"
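
Note on the file.py changes: `isinstance(data, (str, Path))` and `isinstance(data, str | Path)` are equivalent; PEP 604 union objects built with `|` are accepted directly by `isinstance` on Python 3.10+, which this rewrite implies the codebase targets. A quick check of that assumption:

    from pathlib import Path

    # PEP 604 unions (Python 3.10+) work directly in isinstance checks,
    # so the tuple form and the union form behave identically here.
    assert isinstance("data.csv", str | Path)
    assert isinstance(Path("data.csv"), str | Path)
    assert not isinstance(42, str | Path)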

src/guidellm/data/deserializers/huggingface.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict:
         _ = (processor_factory, random_seed)
 
         if isinstance(
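
Note on the huggingface.py change: the widened return type matches how the datasets library actually behaves; `load_dataset` returns a `DatasetDict` (or `IterableDatasetDict` when streaming) if no split is requested, and a single `Dataset`/`IterableDataset` if one is. A sketch, assuming network access to the public "squad" dataset:

    from datasets import Dataset, DatasetDict, load_dataset

    # No split requested: a mapping of split name -> Dataset comes back.
    ds = load_dataset("squad")
    assert isinstance(ds, DatasetDict)

    # A specific split requested: a single Dataset comes back,
    # hence the union return type above.
    train = load_dataset("squad", split="train")
    assert isinstance(train, Dataset)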

src/guidellm/data/deserializers/memory.py

Lines changed: 20 additions & 18 deletions
@@ -33,7 +33,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors
 
         if (
@@ -67,7 +67,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors
 
         if (
@@ -81,9 +81,9 @@ def __call__(
                 f"expected list of dicts, got {data}"
             )
 
-        data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
-        first_keys = set(data[0].keys())
-        for index, item in enumerate(data):
+        typed_data: list[dict[str, Any]] = cast("list[dict[str, Any]]", data)
+        first_keys = set(typed_data[0].keys())
+        for index, item in enumerate(typed_data):
             if set(item.keys()) != first_keys:
                 raise DataNotSupportedError(
                     f"All dictionaries must have the same keys. "
@@ -92,8 +92,8 @@ def __call__(
             )
 
         # Convert list of dicts to dict of lists
-        result_dict = {key: [] for key in first_keys}
-        for item in data:
+        result_dict: dict = {key: [] for key in first_keys}
+        for item in typed_data:
             for key, value in item.items():
                 result_dict[key].append(value)
 
@@ -108,7 +108,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         _ = (processor_factory, random_seed)  # Ignore unused args format errors
 
         primitive_types = (str, int, float, bool, type(None))
@@ -135,7 +135,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         if (
             isinstance(data, str)
             and (json_str := data.strip())
@@ -145,16 +145,18 @@ def __call__(
             )
         ):
             with contextlib.suppress(Exception):
-                parsed = json.loads(data)
+                parsed_data = json.loads(data)
 
-            for deserializer in [
-                InMemoryDictDatasetDeserializer,
-                InMemoryDictListDatasetDeserializer,
-                InMemoryItemListDatasetDeserializer,
-            ]:
+            deserializers = [
+                InMemoryDictDatasetDeserializer(),
+                InMemoryDictListDatasetDeserializer(),
+                InMemoryItemListDatasetDeserializer(),
+            ]
+
+            for deserializer in deserializers:
                 with contextlib.suppress(DataNotSupportedError):
-                    return deserializer()(
-                        parsed, data_kwargs, processor_factory, random_seed
+                    return deserializer(
+                        parsed_data, processor_factory, random_seed, **data_kwargs
                     )
 
         raise DataNotSupportedError(
@@ -171,7 +173,7 @@ def __call__(
         processor_factory: Callable[[], PreTrainedTokenizerBase],
         random_seed: int,
         **data_kwargs: dict[str, Any],
-    ) -> dict[str, list]:
+    ) -> Dataset:
         if (
             isinstance(data, str)
             and (csv_str := data.strip())
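
Note on the memory.py change: beyond the renames, the JSON path fixes a real call bug; the old loop invoked `deserializer()(parsed, data_kwargs, processor_factory, random_seed)`, passing four positional arguments to a protocol that accepts three plus keyword extras. A minimal sketch with a hypothetical function mirroring the protocol's signature:

    from typing import Any

    # Hypothetical stand-in mirroring the deserializer __call__ signature.
    def deserialize(data: Any, processor_factory: Any, random_seed: int, **kw: Any):
        return data, random_seed, kw

    options = {"split": "train"}

    # Old call shape: a fourth positional argument, so this raises TypeError.
    try:
        deserialize("payload", options, None, 42)
    except TypeError as err:
        print(err)

    # Fixed call shape: unpack the extra options as keywords instead.
    assert deserialize("payload", None, 42, **options) == ("payload", 42, options)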

src/guidellm/data/deserializers/synthetic.py

Lines changed: 16 additions & 12 deletions
@@ -99,21 +99,25 @@ class SyntheticTextDatasetConfig(StandardBaseModel):
 
     @model_validator(mode="after")
     def check_prefix_options(self) -> SyntheticTextDatasetConfig:
-        prefix_count = self.__pydantic_extra__.get("prefix_count", None)  # type: ignore[attr-defined]
-        prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None)  # type: ignore[attr-defined]
-        if prefix_count is not None or prefix_tokens is not None:
-            if self.prefix_buckets:
+        prefix_count: Any | None = None
+        prefix_tokens: Any | None = None
+        if self.__pydantic_extra__ is not None:
+            prefix_count = self.__pydantic_extra__.get("prefix_count", None)  # type: ignore[attr-defined]
+            prefix_tokens = self.__pydantic_extra__.get("prefix_tokens", None)  # type: ignore[attr-defined]
+
+        if (prefix_count is not None or prefix_tokens is not None
+                and self.prefix_buckets):
                 raise ValueError(
                     "prefix_buckets is mutually exclusive"
                     " with prefix_count and prefix_tokens"
                 )
 
-            self.prefix_buckets = [
-                SyntheticTextPrefixBucketConfig(
-                    prefix_count=prefix_count or 1,
-                    prefix_tokens=prefix_tokens or 0,
-                )
-            ]
+        self.prefix_buckets = [
+            SyntheticTextPrefixBucketConfig(
+                prefix_count=prefix_count or 1,
+                prefix_tokens=prefix_tokens or 0,
+            )
+        ]
 
         return self
 
@@ -174,14 +178,14 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
     def _create_prompt(
         self, prompt_tokens_count: int, faker: Faker, unique: str = ""
     ) -> str:
-        prompt_token_ids = []
+        prompt_token_ids: list[int] = []
         avg_chars_per_token = 5
         margin_of_safety = 1.5
         attempts = 0
 
         while len(prompt_token_ids) < prompt_tokens_count:
             attempts += 1
-            num_chars = (
+            num_chars = math.ceil(
                 prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
            )
             text = unique + faker.text(max_nb_chars=num_chars)
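
Note on the synthetic.py change: `margin_of_safety = 1.5` makes the character budget a float, while Faker's `text(max_nb_chars=...)` is typed to take an int, so the new `math.ceil` keeps the safety margin while yielding an integer. A small sketch, assuming the faker package is installed:

    import math

    from faker import Faker

    faker = Faker()

    # Ceiling the float budget gives Faker the int it expects.
    num_chars = math.ceil(128 * 5 * 1.5)  # 960 as an int, not 960.0
    print(faker.text(max_nb_chars=num_chars)[:60])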
