25
25
26
26
27
27
class SyntheticTextDatasetConfig (StandardBaseModel ):
28
+ prefix_tokens : int = Field (
29
+ description = "The number of shared prefix tokens to prepend to each prompt." ,
30
+ ge = 0 ,
31
+ default = 0 ,
32
+ )
28
33
prompt_tokens : int = Field (
29
34
description = "The average number of text tokens generated for prompts." ,
30
35
gt = 0 ,
@@ -104,20 +109,29 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
104
109
)
105
110
)
106
111
112
+ # Create a shared prefix if specified
113
+ if self .config .prefix_tokens > 0 :
114
+ prefix = self ._create_prompt (self .config .prefix_tokens , faker )
115
+ else :
116
+ prefix = "" # Always have a prefix key for consistency
117
+
107
118
while True :
108
119
prompt_tokens_count = next (prompt_tokens_sampler )
109
120
output_tokens_count = next (output_tokens_sampler )
110
121
111
122
yield {
123
+ "prefix" : prefix ,
112
124
"prompt" : self ._create_prompt (
113
- prompt_tokens_count , samples_generated , faker
125
+ prompt_tokens_count , faker , f" { samples_generated } "
114
126
),
115
127
"prompt_tokens_count" : prompt_tokens_count ,
116
128
"output_tokens_count" : output_tokens_count ,
117
129
}
118
130
samples_generated += 1
119
131
120
- def _create_prompt (self , prompt_tokens_count : int , index : int , faker : Faker ) -> str :
132
+ def _create_prompt (
133
+ self , prompt_tokens_count : int , faker : Faker , unique : str = ""
134
+ ) -> str :
121
135
prompt_token_ids = []
122
136
avg_chars_per_token = 5
123
137
margin_of_safety = 1.5
@@ -128,7 +142,7 @@ def _create_prompt(self, prompt_tokens_count: int, index: int, faker: Faker) ->
128
142
num_chars = (
129
143
prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
130
144
)
131
- text = f" { index } " + faker .text (max_nb_chars = num_chars )
145
+ text = unique + faker .text (max_nb_chars = num_chars )
132
146
prompt_token_ids = self .processor .encode (text )
133
147
134
148
return self .processor .decode (
@@ -166,6 +180,7 @@ def __call__(
166
180
),
167
181
features = Features (
168
182
{
183
+ "prefix" : Value ("string" ),
169
184
"prompt" : Value ("string" ),
170
185
"prompt_tokens_count" : Value ("int32" ),
171
186
"output_tokens_count" : Value ("int32" ),
0 commit comments