11from __future__ import annotations
22
3- import contextlib
43from collections .abc import Callable
54from typing import Any , Protocol , Union , runtime_checkable
65
7- from datasets import Dataset , IterableDataset
6+ from datasets import Dataset , DatasetDict , IterableDataset , IterableDatasetDict
87from transformers import PreTrainedTokenizerBase
98
109from guidellm .data .utils import resolve_dataset_split
@@ -29,7 +28,7 @@ def __call__(
2928 processor_factory : Callable [[], PreTrainedTokenizerBase ],
3029 random_seed : int ,
3130 ** data_kwargs : dict [str , Any ],
32- ) -> dict [ str , list ] : ...
31+ ) -> Dataset | IterableDataset | DatasetDict | IterableDatasetDict : ...
3332
3433
3534class DatasetDeserializerFactory (
@@ -47,51 +46,16 @@ def deserialize(
4746 remove_columns : list [str ] | None = None ,
4847 ** data_kwargs : dict [str , Any ],
4948 ) -> Dataset | IterableDataset :
50- dataset = None
49+ dataset : Dataset
5150
5251 if type_ is None :
53- errors = []
54- # Note: There is no priority order for the deserializers, so all deserializers
55- # must be mutually exclusive to ensure deterministic behavior.
56- for name , deserializer in cls .registry .items ():
57- deserializer_fn : DatasetDeserializer = (
58- deserializer () if isinstance (deserializer , type ) else deserializer
59- )
60-
61- try :
62- with contextlib .suppress (DataNotSupportedError ):
63- dataset = deserializer_fn (
64- data = data ,
65- processor_factory = processor_factory ,
66- random_seed = random_seed ,
67- ** data_kwargs ,
68- )
69- except Exception as e :
70- errors .append (e )
71-
72- if dataset is not None :
73- break # Found one that works. Continuing could overwrite it.
74-
75- if dataset is None and len (errors ) > 0 :
76- raise DataNotSupportedError (f"data deserialization failed; { len (errors )} errors occurred while "
77- f"attempting to deserialize data { data } : { errors } " )
78-
79- elif deserializer := cls .get_registered_object (type_ ) is not None :
80- deserializer_fn : DatasetDeserializer = (
81- deserializer () if isinstance (deserializer , type ) else deserializer
52+ dataset = cls ._deserialize_with_registered_deserializers (
53+ data , processor_factory , random_seed , ** data_kwargs
8254 )
8355
84- dataset = deserializer_fn (
85- data = data ,
86- processor_factory = processor_factory ,
87- random_seed = random_seed ,
88- ** data_kwargs ,
89- )
90-
91- if dataset is None :
92- raise DataNotSupportedError (
93- f"No suitable deserializer found for data { data } "
94- f"with kwargs { data_kwargs } and deserializer type { type_ } ."
56+ else :
57+ dataset = cls ._deserialize_with_specified_deserializer (
58+ data , type_ , processor_factory , random_seed , ** data_kwargs
9559 )
9660
9761 if resolve_split :
@@ -107,3 +71,74 @@ def deserialize(
10771 dataset = dataset .remove_columns (remove_columns )
10872
10973 return dataset
74+
75+ @classmethod
76+ def _deserialize_with_registered_deserializers (
77+ cls ,
78+ data : Any ,
79+ processor_factory : Callable [[], PreTrainedTokenizerBase ],
80+ random_seed : int = 42 ,
81+ ** data_kwargs : dict [str , Any ],
82+ ) -> Dataset :
83+ if cls .registry is None :
84+ raise RuntimeError ("registry is None; cannot deserialize dataset" )
85+ dataset : Dataset | None = None
86+
87+ errors : dict [str , Exception ] = {}
88+ # Note: There is no priority order for the deserializers, so all deserializers
89+ # must be mutually exclusive to ensure deterministic behavior.
90+ for _name , deserializer in cls .registry .items ():
91+ deserializer_fn : DatasetDeserializer = (
92+ deserializer () if isinstance (deserializer , type ) else deserializer
93+ )
94+
95+ try :
96+ dataset = deserializer_fn (
97+ data = data ,
98+ processor_factory = processor_factory ,
99+ random_seed = random_seed ,
100+ ** data_kwargs ,
101+ )
102+ except Exception as e : # noqa: BLE001 # The exceptions are saved.
103+ errors [_name ] = e
104+
105+ if dataset is not None :
106+ return dataset # Success
107+
108+ if len (errors ) > 0 :
109+ err_msgs = ""
110+ def sort_key (item ):
111+ return (isinstance (item [1 ], DataNotSupportedError ), item [0 ])
112+ for key , err in sorted (errors .items (), key = sort_key ):
113+ err_msgs += f"\n - Deserializer '{ key } ': ({ type (err ).__name__ } ) { err } "
114+ raise ValueError (
115+ "Data deserialization failed, likely because the input doesn't "
116+ f"match any of the input formats. See the { len (errors )} error(s) that "
117+ f"occurred while attempting to deserialize the data { data } :{ err_msgs } "
118+ )
119+ return dataset
120+
121+ @classmethod
122+ def _deserialize_with_specified_deserializer (
123+ cls ,
124+ data : Any ,
125+ type_ : str ,
126+ processor_factory : Callable [[], PreTrainedTokenizerBase ],
127+ random_seed : int = 42 ,
128+ ** data_kwargs : dict [str , Any ],
129+ ) -> Dataset :
130+ deserializer_from_type = cls .get_registered_object (type_ )
131+ if deserializer_from_type is None :
132+ raise ValueError (f"Deserializer type '{ type_ } ' is not registered." )
133+ if isinstance (deserializer_from_type , type ):
134+ deserializer_fn = deserializer_from_type ()
135+ else :
136+ deserializer_fn = deserializer_from_type
137+
138+ return deserializer_fn (
139+ data = data ,
140+ processor_factory = processor_factory ,
141+ random_seed = random_seed ,
142+ ** data_kwargs ,
143+ )
144+
0 commit comments