2525
2626
2727class SyntheticTextDatasetConfig (StandardBaseModel ):
28+ prefix_tokens : int = Field (
29+ description = "The number of shared prefix tokens to prepend to each prompt." ,
30+ ge = 0 ,
31+ default = 0 ,
32+ )
2833 prompt_tokens : int = Field (
2934 description = "The average number of text tokens generated for prompts." ,
3035 gt = 0 ,
@@ -104,20 +109,29 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
104109 )
105110 )
106111
112+ # Build the shared prefix once, before the sampling loop, so every yielded sample carries the same prefix text
113+ if self .config .prefix_tokens > 0 :
114+ prefix = self ._create_prompt (self .config .prefix_tokens , faker )
115+ else :
116+ prefix = "" # Always have a prefix key for consistency
117+
107118 while True :
108119 prompt_tokens_count = next (prompt_tokens_sampler )
109120 output_tokens_count = next (output_tokens_sampler )
110121
111122 yield {
123+ "prefix" : prefix ,
112124 "prompt" : self ._create_prompt (
113- prompt_tokens_count , samples_generated , faker
125+ prompt_tokens_count , faker , f" { samples_generated } "
114126 ),
115127 "prompt_tokens_count" : prompt_tokens_count ,
116128 "output_tokens_count" : output_tokens_count ,
117129 }
118130 samples_generated += 1
119131
120- def _create_prompt (self , prompt_tokens_count : int , index : int , faker : Faker ) -> str :
132+ def _create_prompt (
133+ self , prompt_tokens_count : int , faker : Faker , unique : str = ""
134+ ) -> str :
121135 prompt_token_ids = []
122136 avg_chars_per_token = 5
123137 margin_of_safety = 1.5
@@ -128,7 +142,7 @@ def _create_prompt(self, prompt_tokens_count: int, index: int, faker: Faker) ->
128142 num_chars = (
129143 prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
130144 )
131- text = f" { index } " + faker .text (max_nb_chars = num_chars )
145+ text = unique + faker .text (max_nb_chars = num_chars )
132146 prompt_token_ids = self .processor .encode (text )
133147
134148 return self .processor .decode (
@@ -166,6 +180,7 @@ def __call__(
166180 ),
167181 features = Features (
168182 {
183+ "prefix" : Value ("string" ),
169184 "prompt" : Value ("string" ),
170185 "prompt_tokens_count" : Value ("int32" ),
171186 "output_tokens_count" : Value ("int32" ),
0 commit comments