Skip to content

Commit 388b6c4

Browse files
committed
Try all data deserializers before failing
Signed-off-by: Jared O'Connell <[email protected]>
1 parent 9669983 commit 388b6c4

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

src/guidellm/data/deserializers/deserializer.py

Lines changed: 21 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -50,31 +50,32 @@ def deserialize(
5050
dataset = None
5151

5252
if type_ is None:
53+
errors = []
54+
# Note: There is no priority order for the deserializers, so all deserializers
55+
# must be mutually exclusive to ensure deterministic behavior.
5356
for name, deserializer in cls.registry.items():
54-
if name == "huggingface":
55-
# Save Hugging Face til the end since it is a catch-all.
56-
continue
57-
5857
deserializer_fn: DatasetDeserializer = (
5958
deserializer() if isinstance(deserializer, type) else deserializer
6059
)
6160

62-
with contextlib.suppress(DataNotSupportedError):
63-
dataset = deserializer_fn(
64-
data=data,
65-
processor_factory=processor_factory,
66-
random_seed=random_seed,
67-
**data_kwargs,
68-
)
69-
70-
if dataset is None:
71-
deserializer_fn = cls.get_registered_object("huggingface")()
72-
dataset = deserializer_fn(
73-
data=data,
74-
processor_factory=processor_factory,
75-
random_seed=random_seed,
76-
**data_kwargs,
77-
)
61+
try:
62+
with contextlib.suppress(DataNotSupportedError):
63+
dataset = deserializer_fn(
64+
data=data,
65+
processor_factory=processor_factory,
66+
random_seed=random_seed,
67+
**data_kwargs,
68+
)
69+
except Exception as e:
70+
errors.append(e)
71+
72+
if dataset is not None:
73+
break # Found one that works. Continuing could overwrite it.
74+
75+
if dataset is None and len(errors) > 0:
76+
raise DataNotSupportedError(f"data deserialization failed; {len(errors)} errors occurred while "
77+
f"attempting to deserialize data {data}: {errors}")
78+
7879
elif deserializer := cls.get_registered_object(type_) is not None:
7980
deserializer_fn: DatasetDeserializer = (
8081
deserializer() if isinstance(deserializer, type) else deserializer

0 commit comments

Comments
 (0)