1313
1414from collections .abc import Callable , Mapping , MutableMapping
1515from pathlib import Path
16- from typing import Any , Literal
16+ from typing import Any , Literal , TypeVar
1717
1818from torch .utils .data import Sampler
1919from transformers import PreTrainedTokenizerBase
3636from guidellm .benchmark .schemas .base import TransientPhaseConfig
3737from guidellm .data import (
3838 DataLoader ,
39+ DatasetFinalizer ,
3940 DatasetPreprocessor ,
41+ FinalizerRegistry ,
4042 GenerativeRequestCollator ,
4143 PreprocessorRegistry ,
4244 ProcessorFactory ,
43- RequestFormatter ,
4445)
45- from guidellm .data .preprocessors import GenerativeColumnMapper
4646from guidellm .scheduler import (
4747 ConstraintInitializer ,
4848 NonDistributedEnvironment ,
4949 StrategyType ,
5050)
5151from guidellm .schemas import GenerationRequest , GenerationResponse
52- from guidellm .settings import settings
5352from guidellm .utils import Console , InfoMixin
53+ from guidellm .utils .registry import RegistryMixin
5454
5555__all__ = [
5656 "benchmark_generative_text" ,
@@ -178,19 +178,66 @@ async def resolve_processor(
178178 return processor
179179
180180
181+ BaseTypeT = TypeVar ("BaseTypeT" )
182+
183+
184+ def resolve_item_from_registry (
185+ base_type : type [BaseTypeT ],
186+ registry : type [RegistryMixin ],
187+ item : Any ,
188+ extras : dict [str , Any ] | None = None ,
189+ ) -> BaseTypeT :
190+ """
191+ Resolve an item from a registry, instantiating it if necessary.
192+
193+ :param base_type: The expected base type of the item
194+ :param item: The item to resolve, either an instance or a string identifier
195+ :param registry: The registry to use for resolving string identifiers
196+ :return: The resolved item as an instance of the base type
197+ :raises ValueError: If the item cannot be resolved from the registry
198+ :raises TypeError: If the resolved item is not of the expected base type
199+ """
200+ if isinstance (item , base_type ):
201+ return item
202+ else :
203+ if isinstance (item , str ):
204+ item_type = item
205+ kwargs = {}
206+ else :
207+ item_dict = dict (item )
208+ item_type = item_dict .pop ("type" , None )
209+ if item_type is None :
210+ raise ValueError (
211+ f"Item dictionary must contain a 'type' key to resolve from "
212+ f"{ registry .__class__ .__name__ } ."
213+ )
214+ kwargs = item_dict
215+
216+ if (item_class := registry .get_registered_object (item_type )) is None :
217+ raise ValueError (
218+ f"Item type '{ item_type } ' is not registered in the "
219+ f"{ registry .__class__ .__name__ } ."
220+ )
221+ if not issubclass (item_class , base_type ):
222+ raise TypeError (
223+ f"Resolved item type '{ item_type } ' is not a subclass of "
224+ f"{ base_type .__name__ } ."
225+ )
226+ if extras :
227+ kwargs .update (extras )
228+ return item_class (** kwargs )
229+
230+
181231async def resolve_request_loader (
182232 data : list [Any ],
183233 model : str ,
234+ request_type : str ,
184235 data_args : list [dict [str , Any ]] | None ,
185236 data_samples : int ,
186237 processor : ProcessorInputT | None ,
187238 processor_args : dict [str , Any ] | None ,
188- data_column_mapper : (
189- DatasetPreprocessor
190- | dict [str , str | list [str ]]
191- | Literal ["generative_column_mapper" ]
192- ),
193- data_request_formatter : (RequestFormatter | dict [str , str ] | str ),
239+ data_preprocessors : list [DatasetPreprocessor | dict [str , str | list [str ]] | str ],
240+ data_finalizer : (DatasetFinalizer | dict [str , Any ] | str ),
194241 data_collator : Callable | Literal ["generative" ] | None ,
195242 data_sampler : Sampler [int ] | Literal ["shuffle" ] | None ,
196243 data_num_workers : int | None ,
@@ -232,54 +279,22 @@ async def resolve_request_loader(
232279 else None
233280 )
234281
235- data_column_mapper_instance : DatasetPreprocessor
236- if isinstance (data_column_mapper , DatasetPreprocessor ):
237- data_column_mapper_instance = data_column_mapper
238- else :
239- column_mappings = (
240- data_column_mapper if isinstance (data_column_mapper , dict ) else None
241- )
242- data_column_mapper_instance = GenerativeColumnMapper (
243- column_mappings = column_mappings # type: ignore[arg-type]
244- )
245-
246- data_request_formatter_instance : RequestFormatter
247- if isinstance (data_request_formatter , RequestFormatter ):
248- data_request_formatter_instance = data_request_formatter
249- else :
250- if isinstance (data_request_formatter , str ):
251- request_type = data_request_formatter
252- formatter_kwargs : dict [str , Any ] = {}
253- else :
254- # Extract request_type from formatter dictionary
255- formatter_dict = dict (data_request_formatter )
256- request_type = formatter_dict .pop ("request_type" , settings .preferred_route )
257- formatter_kwargs = formatter_dict
258-
259- if (
260- formatter_class := PreprocessorRegistry .get_registered_object (request_type )
261- ) is None :
262- raise ValueError (
263- f"Request formatter '{ request_type } ' is not registered in the "
264- f"PreprocessorRegistry."
265- )
266- if not issubclass (formatter_class , RequestFormatter ):
267- raise TypeError (
268- f"Request formatter '{ request_type } ' is not a subclass of "
269- f"RequestFormatter."
270- )
271-
272- data_request_formatter_instance = formatter_class (
273- model = model ,
274- ** formatter_kwargs ,
275- )
276-
277- # Cast to proper types for the DataLoader preprocessors list
278282 preprocessors_list : list [DatasetPreprocessor ] = [
279- data_column_mapper_instance ,
280- data_request_formatter_instance ,
283+ resolve_item_from_registry (
284+ DatasetPreprocessor , # type: ignore [type-abstract]
285+ PreprocessorRegistry ,
286+ preprocessor ,
287+ )
288+ for preprocessor in data_preprocessors
281289 ]
282290
291+ finalizer_instance = resolve_item_from_registry (
292+ DatasetFinalizer , # type: ignore [type-abstract]
293+ FinalizerRegistry ,
294+ data_finalizer ,
295+ extras = {"request_type" : request_type },
296+ )
297+
283298 request_loader : DataLoader [GenerationRequest ] = DataLoader (
284299 data = data ,
285300 data_args = data_args ,
@@ -289,6 +304,7 @@ async def resolve_request_loader(
289304 processor_args = processor_args ,
290305 ),
291306 preprocessors = preprocessors_list ,
307+ finalizer = finalizer_instance ,
292308 collator = (
293309 data_collator if callable (data_collator ) else GenerativeRequestCollator ()
294310 ),
@@ -460,12 +476,13 @@ async def benchmark_generative_text(
460476 request_loader = await resolve_request_loader (
461477 data = args .data ,
462478 model = model ,
479+ request_type = args .request_type ,
463480 data_args = args .data_args ,
464481 data_samples = args .data_samples ,
465482 processor = processor ,
466483 processor_args = args .processor_args ,
467- data_column_mapper = args .data_column_mapper ,
468- data_request_formatter = args .data_request_formatter ,
484+ data_preprocessors = args .data_preprocessors ,
485+ data_finalizer = args .data_finalizer ,
469486 data_collator = args .data_collator ,
470487 data_sampler = args .data_sampler ,
471488 data_num_workers = args .data_num_workers ,
0 commit comments