1111from guidellm .config import settings
1212from guidellm .core .request import TextGenerationRequest
1313from guidellm .request .base import GenerationMode , RequestGenerator
14- from guidellm .utils import clean_text , filter_text , load_text , split_text
14+ from guidellm .utils import clean_text , filter_text , load_text , split_text , load_images
1515
1616__all__ = ["EmulatedConfig" , "EmulatedRequestGenerator" , "EndlessTokens" ]
1717
@@ -30,6 +30,7 @@ class EmulatedConfig:
3030 generated_tokens_variance (Optional[int]): Variance for generated tokens.
3131 generated_tokens_min (Optional[int]): Minimum number of generated tokens.
3232 generated_tokens_max (Optional[int]): Maximum number of generated tokens.
33+ images (Optional[int]): Number of input images.
3334 """
3435
3536 @staticmethod
@@ -47,7 +48,7 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
4748 """
4849 if not config :
4950 logger .debug ("Creating default configuration" )
50- return EmulatedConfig (prompt_tokens = 1024 , generated_tokens = 256 )
51+ return EmulatedConfig (prompt_tokens = 1024 , generated_tokens = 256 , images = 0 )
5152
5253 if isinstance (config , dict ):
5354 logger .debug ("Loading configuration from dict: {}" , config )
@@ -105,6 +106,8 @@ def create_config(config: Optional[Union[str, Path, Dict]]) -> "EmulatedConfig":
105106 generated_tokens_min : Optional [int ] = None
106107 generated_tokens_max : Optional [int ] = None
107108
109+ images : int = 0
110+
108111 @property
109112 def prompt_tokens_range (self ) -> Tuple [int , int ]:
110113 """
@@ -327,6 +330,8 @@ def __init__(
327330 settings .emulated_data .filter_start ,
328331 settings .emulated_data .filter_end ,
329332 )
333+ if self ._config .images > 0 :
334+ self ._images = load_images (settings .emulated_data .image_source )
330335 self ._rng = np .random .default_rng (random_seed )
331336
332337 # NOTE: Must be after all the parameters since the queue population
@@ -355,6 +360,7 @@ def create_item(self) -> TextGenerationRequest:
355360 logger .debug ("Creating new text generation request" )
356361 target_prompt_token_count = self ._config .sample_prompt_tokens (self ._rng )
357362 prompt = self .sample_prompt (target_prompt_token_count )
363+ images = self .sample_images ()
358364 prompt_token_count = len (self .tokenizer .tokenize (prompt ))
359365 output_token_count = self ._config .sample_output_tokens (self ._rng )
360366 logger .debug ("Generated prompt: {}" , prompt )
@@ -363,6 +369,7 @@ def create_item(self) -> TextGenerationRequest:
363369 prompt = prompt ,
364370 prompt_token_count = prompt_token_count ,
365371 output_token_count = output_token_count ,
372+ images = images ,
366373 )
367374
368375 def sample_prompt (self , tokens : int ) -> str :
@@ -395,3 +402,9 @@ def sample_prompt(self, tokens: int) -> str:
395402 right = mid
396403
397404 return self ._tokens .create_text (start_line_index , left )
405+
406+
def sample_images(self) -> list:
    """
    Sample a set of distinct images for a generated request.

    Uses the generator's RNG to pick ``self._config.images`` images from the
    loaded image pool without replacement, so a single request never receives
    the same image twice.

    :return: list of sampled images; empty when no images are configured.
    """
    # Guard: __init__ only loads the image pool when config.images > 0,
    # so self._images may not exist. create_item() calls this
    # unconditionally — return an empty list instead of raising.
    if self._config.images <= 0:
        return []

    # replace=False guarantees the sampled indices are distinct.
    image_indices = self._rng.choice(
        len(self._images), size=self._config.images, replace=False
    )

    return [self._images[index] for index in image_indices]
0 commit comments