
Commit 77df1c8

Address review comments
Moved short prompt strategies to a static class

Assisted-by: Cursor AI
Signed-off-by: Jared O'Connell <[email protected]>

1 parent 1eea713

File tree: 2 files changed (+148, -127 lines)

src/guidellm/data/entrypoints.py

Lines changed: 128 additions & 108 deletions
@@ -2,7 +2,7 @@
 from collections.abc import Callable, Iterator
 from enum import Enum
 from pathlib import Path
-from typing import Any
+from typing import Any, cast
 
 from datasets import Dataset
 from loguru import logger
@@ -29,120 +29,136 @@ class ShortPromptStrategy(str, Enum):
     ERROR = "error"
 
 
-def handle_ignore_strategy(
-    current_prompt: str,
-    min_prompt_tokens: int,
-    tokenizer: PreTrainedTokenizerBase,
-    **_kwargs,
-) -> str | None:
-    """
-    Ignores prompts that are shorter than the required minimum token length.
+class ShortPromptStrategyHandler:
+    """Handler class for short prompt strategies."""
 
-    :param current_prompt: The input prompt string.
-    :param min_prompt_tokens: Minimum required token count.
-    :param tokenizer: Tokenizer used to count tokens.
-    :return: The prompt if it meets the length, otherwise None.
-    """
+    @staticmethod
+    def handle_ignore(
+        current_prompt: str,
+        min_prompt_tokens: int,
+        tokenizer: PreTrainedTokenizerBase,
+        **_kwargs,
+    ) -> str | None:
+        """
+        Ignores prompts that are shorter than the required minimum token length.
 
-    if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
-        logger.warning("Prompt too short, ignoring")
-        return None
-    return current_prompt
+        :param current_prompt: The input prompt string.
+        :param min_prompt_tokens: Minimum required token count.
+        :param tokenizer: Tokenizer used to count tokens.
+        :return: The prompt if it meets the length, otherwise None.
+        """
 
-
-def handle_concatenate_strategy(
-    current_prompt: str,
-    min_prompt_tokens: int,
-    dataset_iterator: Iterator[dict[str, Any]],
-    prompt_column: str,
-    tokenizer: PreTrainedTokenizerBase,
-    concat_delimiter: str,
-    **_kwargs,
-) -> str | None:
-    """
-    Concatenates prompts until the minimum token requirement is met.
-
-    :param current_prompt: The initial prompt.
-    :param min_prompt_tokens: Target minimum token length.
-    :param dataset_iterator: Iterator to fetch more prompts.
-    :param prompt_column: Column key for prompt extraction.
-    :param tokenizer: Tokenizer used to count tokens.
-    :param concat_delimiter: Delimiter to use between prompts.
-    :return: Concatenated prompt or None if not enough data.
-    """
-
-    tokens_len = len(tokenizer.encode(current_prompt))
-    while tokens_len < min_prompt_tokens:
-        try:
-            next_row = next(dataset_iterator)
-        except StopIteration:
-            logger.warning(
-                "Could not concatenate enough prompts to reach minimum length, ignoring"
-            )
+        if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
+            logger.warning("Prompt too short, ignoring")
             return None
-        current_prompt += concat_delimiter + next_row[prompt_column]
-        tokens_len = len(tokenizer.encode(current_prompt))
-    return current_prompt
+        return current_prompt
+
+    @staticmethod
+    def handle_concatenate(
+        current_prompt: str,
+        min_prompt_tokens: int,
+        dataset_iterator: Iterator[dict[str, Any]],
+        prompt_column: str,
+        tokenizer: PreTrainedTokenizerBase,
+        concat_delimiter: str,
+        **_kwargs,
+    ) -> str | None:
+        """
+        Concatenates prompts until the minimum token requirement is met.
+
+        :param current_prompt: The initial prompt.
+        :param min_prompt_tokens: Target minimum token length.
+        :param dataset_iterator: Iterator to fetch more prompts.
+        :param prompt_column: Column key for prompt extraction.
+        :param tokenizer: Tokenizer used to count tokens.
+        :param concat_delimiter: Delimiter to use between prompts.
+        :return: Concatenated prompt or None if not enough data.
+        """
 
+        tokens_len = len(tokenizer.encode(current_prompt))
+        while tokens_len < min_prompt_tokens:
+            try:
+                next_row = next(dataset_iterator)
+            except StopIteration:
+                logger.warning(
+                    "Could not concatenate enough prompts to reach minimum "
+                    "length, ignoring"
+                )
+                return None
+            current_prompt += concat_delimiter + next_row[prompt_column]
+            tokens_len = len(tokenizer.encode(current_prompt))
+        return current_prompt
+
+    @staticmethod
+    def handle_pad(
+        current_prompt: str,
+        min_prompt_tokens: int,
+        tokenizer: PreTrainedTokenizerBase,
+        pad_char: str,
+        pad_multiplier: int = 2,
+        **_kwargs,
+    ) -> str:
+        """
+        Pads the prompt with a character until it reaches the minimum token length.
+
+        :param current_prompt: The input prompt.
+        :param min_prompt_tokens: Desired minimum token count.
+        :param tokenizer: Tokenizer used to count tokens.
+        :param pad_char: Character used for padding.
+        :param pad_multiplier: Multiplier for padding character length.
+        :return: Padded prompt string.
+        """
+        tokens = tokenizer.encode(current_prompt)
+        pad_count = 1
+        prompt = current_prompt
+        while len(tokens) < min_prompt_tokens:
+            prompt += pad_char * pad_count
+            tokens = tokenizer.encode(prompt)
+            pad_count *= pad_multiplier
+        return prompt
+
+    @staticmethod
+    def handle_error(
+        current_prompt: str,
+        min_prompt_tokens: int,
+        tokenizer: PreTrainedTokenizerBase,
+        **_kwargs,
+    ) -> str | None:
+        """
+        Raises an error if the prompt is too short.
+
+        :param current_prompt: The input prompt.
+        :param min_prompt_tokens: Required token count.
+        :param tokenizer: Tokenizer used to count tokens.
+        :return: The input prompt if valid.
+        :raises PromptTooShortError: If the prompt is too short.
+        """
+
+        prompt_len = len(tokenizer.encode(current_prompt))
+        if prompt_len < min_prompt_tokens:
+            raise PromptTooShortError(
+                f"Found too short prompt: {current_prompt}, with length: {prompt_len}. "
+                f"Minimum length required: {min_prompt_tokens}.",
+            )
+        return current_prompt
 
-def handle_pad_strategy(
-    current_prompt: str,
-    min_prompt_tokens: int,
-    tokenizer: PreTrainedTokenizerBase,
-    pad_char: str,
-    pad_multiplier: int = 2,
-    **_kwargs,
-) -> str:
-    """
-    Pads the prompt with a character until it reaches the minimum token length.
-
-    :param current_prompt: The input prompt.
-    :param min_prompt_tokens: Desired minimum token count.
-    :param tokenizer: Tokenizer used to count tokens.
-    :param pad_char: Character used for padding.
-    :param pad_multiplier: Multiplier for padding character length.
-    :return: Padded prompt string.
-    """
-    tokens = tokenizer.encode(current_prompt)
-    pad_count = 1
-    prompt = current_prompt
-    while len(tokens) < min_prompt_tokens:
-        prompt += pad_char * pad_count
-        tokens = tokenizer.encode(prompt)
-        pad_count *= pad_multiplier
-    return prompt
-
-
-def handle_error_strategy(
-    current_prompt: str,
-    min_prompt_tokens: int,
-    tokenizer: PreTrainedTokenizerBase,
-    **_kwargs,
-) -> str | None:
-    """
-    Raises an error if the prompt is too short.
-
-    :param current_prompt: The input prompt.
-    :param min_prompt_tokens: Required token count.
-    :param tokenizer: Tokenizer used to count tokens.
-    :return: The input prompt if valid.
-    :raises PromptTooShortError: If the prompt is too short.
-    """
+    @classmethod
+    def get_strategy_handler(cls, strategy: ShortPromptStrategy) -> Callable[..., Any]:
+        """
+        Get the handler for a specific strategy.
 
-    prompt_len = len(tokenizer.encode(current_prompt))
-    if prompt_len < min_prompt_tokens:
-        raise PromptTooShortError(
-            f"Found too short prompt: {current_prompt}, with length: {prompt_len}. "
-            f"Minimum length required: {min_prompt_tokens}.",
-        )
-    return current_prompt
+        :param strategy: The short prompt strategy to get the handler for.
+        :return: The handler callable for the specified strategy.
+        """
+        return cast("Callable[..., Any]", STRATEGY_HANDLERS[strategy])
 
 
-STRATEGY_HANDLERS: dict[ShortPromptStrategy, Callable] = {
-    ShortPromptStrategy.IGNORE: handle_ignore_strategy,
-    ShortPromptStrategy.CONCATENATE: handle_concatenate_strategy,
-    ShortPromptStrategy.PAD: handle_pad_strategy,
-    ShortPromptStrategy.ERROR: handle_error_strategy,
+# Initialize STRATEGY_HANDLERS after class definition to allow method references
+STRATEGY_HANDLERS = {
+    ShortPromptStrategy.IGNORE: ShortPromptStrategyHandler.handle_ignore,
+    ShortPromptStrategy.CONCATENATE: ShortPromptStrategyHandler.handle_concatenate,
+    ShortPromptStrategy.PAD: ShortPromptStrategyHandler.handle_pad,
+    ShortPromptStrategy.ERROR: ShortPromptStrategyHandler.handle_error,
 }
@@ -245,7 +261,9 @@ def process_dataset(
     )
 
     # Setup column mapper
-    column_mapper = GenerativeColumnMapper(column_mappings=data_column_mapper)  # type: ignore[arg-type]
+    column_mapper = GenerativeColumnMapper(
+        column_mappings=data_column_mapper  # type: ignore[arg-type]
+    )
     column_mapper.setup_data(
         datasets=[dataset],
         data_args=[data_args or {}],
@@ -265,7 +283,9 @@ def process_dataset(
     # Process dataset
     dataset_iterator = iter(dataset)
    processed_prompts = []
-    prompt_handler = STRATEGY_HANDLERS[short_prompt_strategy]
+    prompt_handler = ShortPromptStrategyHandler.get_strategy_handler(
+        short_prompt_strategy
+    )
 
     for row in dataset_iterator:
         processed_row = _process_single_row(
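
Taken together, these changes replace the four module-level handler functions with static methods on ShortPromptStrategyHandler, resolved via get_strategy_handler. A minimal usage sketch of the new surface, not part of the diff; the MagicMock tokenizer stub is illustrative only, and real call sites pass a Hugging Face PreTrainedTokenizerBase:

from unittest.mock import MagicMock

from guidellm.data.entrypoints import (
    ShortPromptStrategy,
    ShortPromptStrategyHandler,
)

# Illustrative tokenizer stand-in: encode() returns one token per character,
# so minimum-token checks simply count characters in this sketch.
tokenizer = MagicMock()
tokenizer.encode.side_effect = lambda text: list(text)

# Resolve the handler from the enum, as process_dataset now does, then call
# it exactly like the old module-level handle_pad_strategy function.
handler = ShortPromptStrategyHandler.get_strategy_handler(ShortPromptStrategy.PAD)
padded = handler("short", 10, tokenizer, pad_char="p")
assert len(tokenizer.encode(padded)) >= 10

The call signatures are unchanged by the refactor, so keyword arguments such as pad_char and pad_multiplier pass through exactly as before; only the namespace moved.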

tests/unit/data/test_entrypoints.py

Lines changed: 20 additions & 19 deletions
@@ -16,13 +16,9 @@
 from transformers import PreTrainedTokenizerBase
 
 from guidellm.data.entrypoints import (
-    STRATEGY_HANDLERS,
     PromptTooShortError,
     ShortPromptStrategy,
-    handle_concatenate_strategy,
-    handle_error_strategy,
-    handle_ignore_strategy,
-    handle_pad_strategy,
+    ShortPromptStrategyHandler,
     process_dataset,
     push_dataset_to_hub,
 )
@@ -1742,54 +1738,58 @@ class TestShortPromptStrategyHandlers:
 
     @pytest.mark.sanity
     def test_handle_ignore_strategy_too_short(self, tokenizer_mock):
-        """Test handle_ignore_strategy returns None for short prompts."""
-        result = handle_ignore_strategy("short", 10, tokenizer_mock)
+        """Test handle_ignore returns None for short prompts."""
+        result = ShortPromptStrategyHandler.handle_ignore("short", 10, tokenizer_mock)
         assert result is None
         tokenizer_mock.encode.assert_called_with("short")
 
     @pytest.mark.sanity
     def test_handle_ignore_strategy_sufficient_length(self, tokenizer_mock):
-        """Test handle_ignore_strategy returns prompt for sufficient length."""
-        result = handle_ignore_strategy("long prompt", 5, tokenizer_mock)
+        """Test handle_ignore returns prompt for sufficient length."""
+        result = ShortPromptStrategyHandler.handle_ignore(
+            "long prompt", 5, tokenizer_mock
+        )
         assert result == "long prompt"
         tokenizer_mock.encode.assert_called_with("long prompt")
 
     @pytest.mark.sanity
     def test_handle_concatenate_strategy_enough_prompts(self, tokenizer_mock):
-        """Test handle_concatenate_strategy with enough prompts."""
+        """Test handle_concatenate with enough prompts."""
         dataset_iter = iter([{"prompt": "longer"}])
-        result = handle_concatenate_strategy(
+        result = ShortPromptStrategyHandler.handle_concatenate(
             "short", 10, dataset_iter, "prompt", tokenizer_mock, "\n"
         )
         assert result == "short\nlonger"
 
     @pytest.mark.sanity
     def test_handle_concatenate_strategy_not_enough_prompts(self, tokenizer_mock):
-        """Test handle_concatenate_strategy without enough prompts."""
+        """Test handle_concatenate without enough prompts."""
         dataset_iter: Iterator = iter([])
-        result = handle_concatenate_strategy(
+        result = ShortPromptStrategyHandler.handle_concatenate(
             "short", 10, dataset_iter, "prompt", tokenizer_mock, ""
         )
         assert result is None
 
     @pytest.mark.sanity
     def test_handle_pad_strategy(self, tokenizer_mock):
-        """Test handle_pad_strategy pads short prompts."""
-        result = handle_pad_strategy("short", 10, tokenizer_mock, "p")
+        """Test handle_pad pads short prompts."""
+        result = ShortPromptStrategyHandler.handle_pad("short", 10, tokenizer_mock, "p")
         assert result.startswith("shortppppp")
 
     @pytest.mark.sanity
     def test_handle_error_strategy_valid_prompt(self, tokenizer_mock):
-        """Test handle_error_strategy returns prompt for valid length."""
-        result = handle_error_strategy("valid prompt", 5, tokenizer_mock)
+        """Test handle_error returns prompt for valid length."""
+        result = ShortPromptStrategyHandler.handle_error(
+            "valid prompt", 5, tokenizer_mock
+        )
         assert result == "valid prompt"
         tokenizer_mock.encode.assert_called_with("valid prompt")
 
     @pytest.mark.sanity
     def test_handle_error_strategy_too_short_prompt(self, tokenizer_mock):
-        """Test handle_error_strategy raises error for short prompts."""
+        """Test handle_error raises error for short prompts."""
         with pytest.raises(PromptTooShortError):
-            handle_error_strategy("short", 10, tokenizer_mock)
+            ShortPromptStrategyHandler.handle_error("short", 10, tokenizer_mock)
 
 
 class TestProcessDatasetPushToHub:
@@ -1915,6 +1915,7 @@ def test_strategy_handler_called(
         tmp_path,
     ):
         """Test that strategy handlers are called during dataset processing."""
+        from guidellm.data.entrypoints import STRATEGY_HANDLERS
         mock_handler = MagicMock(return_value="processed_prompt")
         with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: mock_handler}):
             # Create a dataset with prompts that need processing
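
One detail worth noting in the last hunk: the module-level STRATEGY_HANDLERS import was dropped in the first hunk, and test_strategy_handler_called now imports it locally before patching. A standalone sketch of that patching pattern, grounded in the test code above; the mock handler and its arguments are illustrative only:

from unittest.mock import MagicMock, patch

from guidellm.data.entrypoints import (
    STRATEGY_HANDLERS,
    ShortPromptStrategy,
    ShortPromptStrategyHandler,
)

# patch.dict swaps the IGNORE entry in the module-level dict in place and
# restores it on exit; get_strategy_handler indexes that same dict, so it
# returns the mock while the context manager is active.
mock_handler = MagicMock(return_value="processed_prompt")
with patch.dict(STRATEGY_HANDLERS, {ShortPromptStrategy.IGNORE: mock_handler}):
    handler = ShortPromptStrategyHandler.get_strategy_handler(
        ShortPromptStrategy.IGNORE
    )
    assert handler("anything", 10, None) == "processed_prompt"

Because get_strategy_handler is just a dict lookup behind a classmethod, patching the dict redirects both access paths at once, which is what lets the production code switch to the classmethod without changing the test's patching strategy.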
