|
7 | 7 | DatasetPreprocessor, |
8 | 8 | PreprocessorRegistry, |
9 | 9 | ) |
10 | | -from guidellm.data.utils import text_stats |
11 | 10 | from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics |
12 | 11 |
|
13 | 12 | __all__ = [ |
@@ -102,10 +101,10 @@ def __call__( |
102 | 101 | prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) |
103 | 102 | text = "".join(txt for txt in columns.get("text_column", []) if txt) |
104 | 103 | if prefix or text: |
105 | | - arguments.body["prompt"] = prefix + text |
106 | | - stats = text_stats(arguments.body["prompt"]) |
107 | | - input_metrics.text_characters = stats.get("num_chars") |
108 | | - input_metrics.text_words = stats.get("num_words") |
| 104 | + prompt = prefix + text |
| 105 | + arguments.body["prompt"] = prompt |
| 106 | + input_metrics.text_characters = len(prompt) |
| 107 | + input_metrics.text_words = len(prompt.split()) |
109 | 108 |
|
110 | 109 | return GenerationRequest( |
111 | 110 | request_type="text_completions", |
@@ -198,27 +197,25 @@ def __call__( # noqa: C901, PLR0912, PLR0915 |
198 | 197 | if not prefix: |
199 | 198 | continue |
200 | 199 |
|
201 | | - stats = text_stats(prefix) |
202 | | - if (num_chars := stats.get("num_chars")) is not None: |
203 | | - input_metrics.text_characters = ( |
204 | | - input_metrics.text_characters or 0 |
205 | | - ) + num_chars |
206 | | - if (num_words := stats.get("num_words")) is not None: |
207 | | - input_metrics.text_words = (input_metrics.text_words or 0) + num_words |
| 200 | + input_metrics.text_characters = ( |
| 201 | + input_metrics.text_characters or 0 |
| 202 | + ) + len(prefix) |
| 203 | + |
| 204 | + input_metrics.text_words = (input_metrics.text_words or 0) + \ |
| 205 | + len(prefix.split()) |
208 | 206 |
|
209 | 207 | arguments.body["messages"].append({"role": "system", "content": prefix}) |
210 | 208 |
|
211 | 209 | for text in columns.get("text_column", []): |
212 | 210 | if not text: |
213 | 211 | continue |
214 | 212 |
|
215 | | - stats = text_stats(text) |
216 | | - if (num_chars := stats.get("num_chars")) is not None: |
217 | | - input_metrics.text_characters = ( |
218 | | - input_metrics.text_characters or 0 |
219 | | - ) + num_chars |
220 | | - if (num_words := stats.get("num_words")) is not None: |
221 | | - input_metrics.text_words = (input_metrics.text_words or 0) + num_words |
| 213 | + input_metrics.text_characters = ( |
| 214 | + input_metrics.text_characters or 0 |
| 215 | + ) + len(text) |
| 216 | + input_metrics.text_words = ( |
| 217 | + input_metrics.text_words or 0 |
| 218 | + ) + len(text.split()) |
222 | 219 |
|
223 | 220 | arguments.body["messages"].append( |
224 | 221 | {"role": "user", "content": [{"type": "text", "text": text}]} |
@@ -395,10 +392,10 @@ def __call__( # noqa: C901 |
395 | 392 | prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre) |
396 | 393 | text = "".join(txt for txt in columns.get("text_column", []) if txt) |
397 | 394 | if prefix or text: |
398 | | - arguments.body["prompt"] = prefix + text |
399 | | - stats = text_stats(arguments.body["prompt"]) |
400 | | - input_metrics.text_characters = stats.get("num_chars") |
401 | | - input_metrics.text_words = stats.get("num_words") |
| 395 | + prompt = prefix + text |
| 396 | + arguments.body["prompt"] = prompt |
| 397 | + input_metrics.text_characters = len(prompt) |
| 398 | + input_metrics.text_words = len(prompt.split()) |
402 | 399 |
|
403 | 400 | return GenerationRequest( |
404 | 401 | request_type="audio_transcriptions", |
|
0 commit comments