2
2
import random
3
3
from collections .abc import Iterable , Iterator
4
4
from pathlib import Path
5
- from typing import Any , Literal , Optional , Union
5
+ from typing import Any , Optional , TypedDict , Union
6
6
7
7
import yaml
8
8
from datasets import (
@@ -63,6 +63,26 @@ class SyntheticDatasetConfig(BaseModel):
63
63
gt = 0 ,
64
64
default = None ,
65
65
)
66
+ turns : int = Field (
67
+ description = "The number of turns in the conversation." ,
68
+ gt = 0 ,
69
+ default = 1 ,
70
+ )
71
+ turns_stdev : Optional [int ] = Field (
72
+ description = "The standard deviation of the number of turns." ,
73
+ gt = 0 ,
74
+ default = None ,
75
+ )
76
+ turns_min : Optional [int ] = Field (
77
+ description = "The minimum number of turns in the conversation." ,
78
+ gt = 0 ,
79
+ default = None ,
80
+ )
81
+ turns_max : Optional [int ] = Field (
82
+ description = "The maximum number of turns in the conversation." ,
83
+ gt = 0 ,
84
+ default = None ,
85
+ )
66
86
samples : int = Field (
67
87
description = "The number of samples to generate for the dataset." ,
68
88
gt = 0 ,
@@ -118,14 +138,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
118
138
return SyntheticDatasetConfig (** config_dict )
119
139
120
140
121
- class SyntheticTextItemsGenerator (
122
- Iterable [
123
- dict [
124
- Literal ["prompt" , "prompt_tokens_count" , "output_tokens_count" ],
125
- Union [str , int ],
126
- ]
127
- ]
128
- ):
141
+ class SyntheticDatasetRow (TypedDict ):
142
+ prompt : list [str ]
143
+ prompt_tokens_count : list [int ]
144
+ output_tokens_count : list [int ]
145
+
146
+
147
+ class SyntheticTextItemsGenerator (Iterable [SyntheticDatasetRow ]):
129
148
def __init__ (
130
149
self ,
131
150
config : SyntheticDatasetConfig ,
@@ -141,12 +160,7 @@ def __init__(
141
160
142
161
def __iter__ (
143
162
self ,
144
- ) -> Iterator [
145
- dict [
146
- Literal ["prompt" , "prompt_tokens_count" , "output_tokens_count" ],
147
- Union [str , int ],
148
- ]
149
- ]:
163
+ ) -> Iterator [SyntheticDatasetRow ]:
150
164
prompt_tokens_sampler = IntegerRangeSampler (
151
165
average = self .config .prompt_tokens ,
152
166
variance = self .config .prompt_tokens_stdev ,
@@ -161,20 +175,33 @@ def __iter__(
161
175
max_value = self .config .output_tokens_max ,
162
176
random_seed = self .random_seed + 1 , # ensure diff dist from prompts
163
177
)
178
+ turns_sampler = IntegerRangeSampler (
179
+ average = self .config .turns ,
180
+ variance = self .config .turns_stdev ,
181
+ min_value = self .config .turns_min ,
182
+ max_value = self .config .turns_max ,
183
+ random_seed = self .random_seed + 7 , # ensure diff dist
184
+ )
164
185
# ensure diff distribution from output tokens
165
186
rand = random .Random (self .random_seed + 2 ) # noqa: S311
166
187
167
- for _ , prompt_tokens , output_tokens in zip (
168
- range (self .config .samples ),
169
- prompt_tokens_sampler ,
170
- output_tokens_sampler ,
171
- ):
172
- start_index = rand .randint (0 , len (self .text_creator .words ))
173
- yield {
174
- "prompt" : self ._create_prompt (prompt_tokens , start_index ),
175
- "prompt_tokens_count" : prompt_tokens ,
176
- "output_tokens_count" : output_tokens ,
188
+ for _ , turns in zip (range (self .config .samples ), turns_sampler ):
189
+ row : SyntheticDatasetRow = {
190
+ "prompt" : [],
191
+ "prompt_tokens_count" : [],
192
+ "output_tokens_count" : [],
177
193
}
194
+ for _ , prompt_tokens , output_tokens in zip (
195
+ range (turns ),
196
+ prompt_tokens_sampler ,
197
+ output_tokens_sampler ,
198
+ ):
199
+ start_index = rand .randint (0 , len (self .text_creator .words ))
200
+ row ["prompt" ].append (self ._create_prompt (prompt_tokens , start_index ))
201
+ row ["prompt_tokens_count" ].append (prompt_tokens )
202
+ row ["output_tokens_count" ].append (output_tokens )
203
+
204
+ yield row
178
205
179
206
def _create_prompt (self , prompt_tokens : int , start_index : int ) -> str :
180
207
if prompt_tokens <= 0 :
0 commit comments