Skip to content

Commit 81817f4

Browse files
committed
Address review comments
Signed-off-by: Jared O'Connell <[email protected]>
1 parent ad41a66 commit 81817f4

File tree

5 files changed

+25
-50
lines changed

5 files changed

+25
-50
lines changed

src/guidellm/data/deserializers/synthetic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def _create_prompt(
183183

184184
while len(prompt_token_ids) < prompt_tokens_count:
185185
attempts += 1
186-
num_chars = math.ceil(
186+
num_chars = int(
187187
prompt_tokens_count * avg_chars_per_token * margin_of_safety * attempts
188188
)
189189
text = unique + faker.text(max_nb_chars=num_chars)

src/guidellm/data/loaders.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
__all__ = ["DataLoader", "DatasetsIterator"]
1818

19-
from guidellm.schemas import GenerationRequest
2019

2120

2221
class DatasetsIterator(TorchIterableDataset):
@@ -100,11 +99,10 @@ def generator(
10099
continue
101100

102101
for preprocessor in self.preprocessors:
103-
processed_row = preprocessor(row)
104-
if isinstance(processed_row, GenerationRequest):
105-
yield processed_row
106-
else:
107-
row = processed_row
102+
# This can assign a GenerationRequest, which would then be
103+
# passed into the preprocessor, which is a type violation.
104+
# This should be fixed at some point.
105+
row = preprocessor(row) # type: ignore[assignment]
108106
yield row
109107
except Exception as err: # noqa: BLE001 # Exception logged
110108
logger.error(f"Skipping data row due to error: {err}")

src/guidellm/data/preprocessors/formatters.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
DatasetPreprocessor,
88
PreprocessorRegistry,
99
)
10-
from guidellm.data.utils import text_stats
1110
from guidellm.schemas import GenerationRequest, GenerationRequestArguments, UsageMetrics
1211

1312
__all__ = [
@@ -102,10 +101,10 @@ def __call__(
102101
prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre)
103102
text = "".join(txt for txt in columns.get("text_column", []) if txt)
104103
if prefix or text:
105-
arguments.body["prompt"] = prefix + text
106-
stats = text_stats(arguments.body["prompt"])
107-
input_metrics.text_characters = stats.get("num_chars")
108-
input_metrics.text_words = stats.get("num_words")
104+
prompt = prefix + text
105+
arguments.body["prompt"] = prompt
106+
input_metrics.text_characters = len(prompt)
107+
input_metrics.text_words = len(prompt.split())
109108

110109
return GenerationRequest(
111110
request_type="text_completions",
@@ -198,27 +197,25 @@ def __call__( # noqa: C901, PLR0912, PLR0915
198197
if not prefix:
199198
continue
200199

201-
stats = text_stats(prefix)
202-
if (num_chars := stats.get("num_chars")) is not None:
203-
input_metrics.text_characters = (
204-
input_metrics.text_characters or 0
205-
) + num_chars
206-
if (num_words := stats.get("num_words")) is not None:
207-
input_metrics.text_words = (input_metrics.text_words or 0) + num_words
200+
input_metrics.text_characters = (
201+
input_metrics.text_characters or 0
202+
) + len(prefix)
203+
204+
input_metrics.text_words = (input_metrics.text_words or 0) + \
205+
len(prefix.split())
208206

209207
arguments.body["messages"].append({"role": "system", "content": prefix})
210208

211209
for text in columns.get("text_column", []):
212210
if not text:
213211
continue
214212

215-
stats = text_stats(text)
216-
if (num_chars := stats.get("num_chars")) is not None:
217-
input_metrics.text_characters = (
218-
input_metrics.text_characters or 0
219-
) + num_chars
220-
if (num_words := stats.get("num_words")) is not None:
221-
input_metrics.text_words = (input_metrics.text_words or 0) + num_words
213+
input_metrics.text_characters = (
214+
input_metrics.text_characters or 0
215+
) + len(text)
216+
input_metrics.text_words = (
217+
input_metrics.text_words or 0
218+
) + len(text.split())
222219

223220
arguments.body["messages"].append(
224221
{"role": "user", "content": [{"type": "text", "text": text}]}
@@ -395,10 +392,10 @@ def __call__( # noqa: C901
395392
prefix = "".join(pre for pre in columns.get("prefix_column", []) if pre)
396393
text = "".join(txt for txt in columns.get("text_column", []) if txt)
397394
if prefix or text:
398-
arguments.body["prompt"] = prefix + text
399-
stats = text_stats(arguments.body["prompt"])
400-
input_metrics.text_characters = stats.get("num_chars")
401-
input_metrics.text_words = stats.get("num_words")
395+
prompt = prefix + text
396+
arguments.body["prompt"] = prompt
397+
input_metrics.text_characters = len(prompt)
398+
input_metrics.text_words = len(prompt.split())
402399

403400
return GenerationRequest(
404401
request_type="audio_transcriptions",
Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
from .dataset import DEFAULT_SPLITS, resolve_dataset_split
2-
from .functions import (
3-
text_stats,
4-
)
52

63
__all__ = [
74
"DEFAULT_SPLITS",
85
"resolve_dataset_split",
9-
"text_stats",
106
]

src/guidellm/data/utils/functions.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments (0)