@@ -2,7 +2,7 @@
 from collections.abc import Callable, Iterator
 from enum import Enum
 from pathlib import Path
-from typing import Any
+from typing import Any, cast

 from datasets import Dataset
 from loguru import logger
@@ -29,120 +29,136 @@ class ShortPromptStrategy(str, Enum):
     ERROR = "error"


-def handle_ignore_strategy(
-    current_prompt: str,
-    min_prompt_tokens: int,
-    tokenizer: PreTrainedTokenizerBase,
-    **_kwargs,
-) -> str | None:
-    """
-    Ignores prompts that are shorter than the required minimum token length.
+class ShortPromptStrategyHandler:
+    """Handler class for short prompt strategies."""

-    :param current_prompt: The input prompt string.
-    :param min_prompt_tokens: Minimum required token count.
-    :param tokenizer: Tokenizer used to count tokens.
-    :return: The prompt if it meets the length, otherwise None.
-    """
+    @staticmethod
+    def handle_ignore(
+        current_prompt: str,
+        min_prompt_tokens: int,
+        tokenizer: PreTrainedTokenizerBase,
+        **_kwargs,
+    ) -> str | None:
+        """
+        Ignores prompts that are shorter than the required minimum token length.

-    if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
-        logger.warning("Prompt too short, ignoring")
-        return None
-    return current_prompt
+        :param current_prompt: The input prompt string.
+        :param min_prompt_tokens: Minimum required token count.
+        :param tokenizer: Tokenizer used to count tokens.
+        :return: The prompt if it meets the length, otherwise None.
+        """

-
-def handle_concatenate_strategy(
-    current_prompt: str,
-    min_prompt_tokens: int,
-    dataset_iterator: Iterator[dict[str, Any]],
-    prompt_column: str,
-    tokenizer: PreTrainedTokenizerBase,
-    concat_delimiter: str,
-    **_kwargs,
-) -> str | None:
-    """
-    Concatenates prompts until the minimum token requirement is met.
-
-    :param current_prompt: The initial prompt.
-    :param min_prompt_tokens: Target minimum token length.
-    :param dataset_iterator: Iterator to fetch more prompts.
-    :param prompt_column: Column key for prompt extraction.
-    :param tokenizer: Tokenizer used to count tokens.
-    :param concat_delimiter: Delimiter to use between prompts.
-    :return: Concatenated prompt or None if not enough data.
-    """
-
-    tokens_len = len(tokenizer.encode(current_prompt))
-    while tokens_len < min_prompt_tokens:
-        try:
-            next_row = next(dataset_iterator)
-        except StopIteration:
-            logger.warning(
-                "Could not concatenate enough prompts to reach minimum length, ignoring"
-            )
+        if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
+            logger.warning("Prompt too short, ignoring")
             return None
-        current_prompt += concat_delimiter + next_row[prompt_column]
-        tokens_len = len(tokenizer.encode(current_prompt))
-    return current_prompt
+        return current_prompt
+
+    @staticmethod
+    def handle_concatenate(
+        current_prompt: str,
+        min_prompt_tokens: int,
+        dataset_iterator: Iterator[dict[str, Any]],
+        prompt_column: str,
+        tokenizer: PreTrainedTokenizerBase,
+        concat_delimiter: str,
+        **_kwargs,
+    ) -> str | None:
+        """
+        Concatenates prompts until the minimum token requirement is met.
+
+        :param current_prompt: The initial prompt.
+        :param min_prompt_tokens: Target minimum token length.
+        :param dataset_iterator: Iterator to fetch more prompts.
+        :param prompt_column: Column key for prompt extraction.
+        :param tokenizer: Tokenizer used to count tokens.
+        :param concat_delimiter: Delimiter to use between prompts.
+        :return: Concatenated prompt or None if not enough data.
+        """

+        tokens_len = len(tokenizer.encode(current_prompt))
+        while tokens_len < min_prompt_tokens:
+            try:
+                next_row = next(dataset_iterator)
+            except StopIteration:
+                logger.warning(
+                    "Could not concatenate enough prompts to reach minimum "
+                    "length, ignoring"
+                )
+                return None
+            current_prompt += concat_delimiter + next_row[prompt_column]
+            tokens_len = len(tokenizer.encode(current_prompt))
+        return current_prompt
+
+    @staticmethod
+    def handle_pad(
+        current_prompt: str,
+        min_prompt_tokens: int,
+        tokenizer: PreTrainedTokenizerBase,
+        pad_char: str,
+        pad_multiplier: int = 2,
+        **_kwargs,
+    ) -> str:
+        """
+        Pads the prompt with a character until it reaches the minimum token length.
+
+        :param current_prompt: The input prompt.
+        :param min_prompt_tokens: Desired minimum token count.
+        :param tokenizer: Tokenizer used to count tokens.
+        :param pad_char: Character used for padding.
+        :param pad_multiplier: Multiplier for padding character length.
+        :return: Padded prompt string.
+        """
+        tokens = tokenizer.encode(current_prompt)
+        pad_count = 1
+        prompt = current_prompt
+        while len(tokens) < min_prompt_tokens:
+            prompt += pad_char * pad_count
+            tokens = tokenizer.encode(prompt)
+            pad_count *= pad_multiplier
+        return prompt
+
+    @staticmethod
+    def handle_error(
+        current_prompt: str,
+        min_prompt_tokens: int,
+        tokenizer: PreTrainedTokenizerBase,
+        **_kwargs,
+    ) -> str | None:
+        """
+        Raises an error if the prompt is too short.
+
+        :param current_prompt: The input prompt.
+        :param min_prompt_tokens: Required token count.
+        :param tokenizer: Tokenizer used to count tokens.
+        :return: The input prompt if valid.
+        :raises PromptTooShortError: If the prompt is too short.
+        """
+
+        prompt_len = len(tokenizer.encode(current_prompt))
+        if prompt_len < min_prompt_tokens:
+            raise PromptTooShortError(
+                f"Found too short prompt: {current_prompt}, with length: {prompt_len}. "
+                f"Minimum length required: {min_prompt_tokens}.",
+            )
+        return current_prompt

-def handle_pad_strategy(
-    current_prompt: str,
-    min_prompt_tokens: int,
-    tokenizer: PreTrainedTokenizerBase,
-    pad_char: str,
-    pad_multiplier: int = 2,
-    **_kwargs,
-) -> str:
-    """
-    Pads the prompt with a character until it reaches the minimum token length.
-
-    :param current_prompt: The input prompt.
-    :param min_prompt_tokens: Desired minimum token count.
-    :param tokenizer: Tokenizer used to count tokens.
-    :param pad_char: Character used for padding.
-    :param pad_multiplier: Multiplier for padding character length.
-    :return: Padded prompt string.
-    """
-    tokens = tokenizer.encode(current_prompt)
-    pad_count = 1
-    prompt = current_prompt
-    while len(tokens) < min_prompt_tokens:
-        prompt += pad_char * pad_count
-        tokens = tokenizer.encode(prompt)
-        pad_count *= pad_multiplier
-    return prompt
-
-
-def handle_error_strategy(
-    current_prompt: str,
-    min_prompt_tokens: int,
-    tokenizer: PreTrainedTokenizerBase,
-    **_kwargs,
-) -> str | None:
-    """
-    Raises an error if the prompt is too short.
-
-    :param current_prompt: The input prompt.
-    :param min_prompt_tokens: Required token count.
-    :param tokenizer: Tokenizer used to count tokens.
-    :return: The input prompt if valid.
-    :raises PromptTooShortError: If the prompt is too short.
-    """
+    @classmethod
+    def get_strategy_handler(cls, strategy: ShortPromptStrategy) -> Callable[..., Any]:
+        """
+        Get the handler for a specific strategy.

-    prompt_len = len(tokenizer.encode(current_prompt))
-    if prompt_len < min_prompt_tokens:
-        raise PromptTooShortError(
-            f"Found too short prompt: {current_prompt}, with length: {prompt_len}. "
-            f"Minimum length required: {min_prompt_tokens}.",
-        )
-    return current_prompt
+        :param strategy: The short prompt strategy to get the handler for.
+        :return: The handler callable for the specified strategy.
+        """
+        return cast("Callable[..., Any]", STRATEGY_HANDLERS[strategy])


-STRATEGY_HANDLERS: dict[ShortPromptStrategy, Callable] = {
-    ShortPromptStrategy.IGNORE: handle_ignore_strategy,
-    ShortPromptStrategy.CONCATENATE: handle_concatenate_strategy,
-    ShortPromptStrategy.PAD: handle_pad_strategy,
-    ShortPromptStrategy.ERROR: handle_error_strategy,
+# Initialize STRATEGY_HANDLERS after class definition to allow method references
+STRATEGY_HANDLERS = {
+    ShortPromptStrategy.IGNORE: ShortPromptStrategyHandler.handle_ignore,
+    ShortPromptStrategy.CONCATENATE: ShortPromptStrategyHandler.handle_concatenate,
+    ShortPromptStrategy.PAD: ShortPromptStrategyHandler.handle_pad,
+    ShortPromptStrategy.ERROR: ShortPromptStrategyHandler.handle_error,
 }


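Reviewer note (not part of the diff): the class-based lookup above is behavior-preserving; the following is a minimal usage sketch under stated assumptions. ToyTokenizer is a hypothetical stand-in for PreTrainedTokenizerBase that counts whitespace-separated words, and the prompt text and pad_char values are placeholders.

# Illustrative sketch only, not part of this commit.
# Assumes ShortPromptStrategyHandler and ShortPromptStrategy are imported
# from the module changed in this commit.
class ToyTokenizer:
    # Hypothetical stand-in: one "token" per whitespace-separated word.
    def encode(self, text: str) -> list[int]:
        return [0] * len(text.split())

handler = ShortPromptStrategyHandler.get_strategy_handler(ShortPromptStrategy.PAD)
padded = handler(
    current_prompt="hi there",
    min_prompt_tokens=5,
    tokenizer=ToyTokenizer(),
    pad_char=" x",  # each " x" adds one whitespace-separated "token"
)
assert len(ToyTokenizer().encode(padded)) >= 5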
@@ -245,7 +261,9 @@ def process_dataset(
     )

     # Setup column mapper
-    column_mapper = GenerativeColumnMapper(column_mappings=data_column_mapper)  # type: ignore[arg-type]
+    column_mapper = GenerativeColumnMapper(
+        column_mappings=data_column_mapper  # type: ignore[arg-type]
+    )
     column_mapper.setup_data(
         datasets=[dataset],
         data_args=[data_args or {}],
@@ -265,7 +283,9 @@ def process_dataset(
     # Process dataset
     dataset_iterator = iter(dataset)
     processed_prompts = []
-    prompt_handler = STRATEGY_HANDLERS[short_prompt_strategy]
+    prompt_handler = ShortPromptStrategyHandler.get_strategy_handler(
+        short_prompt_strategy
+    )

     for row in dataset_iterator:
         processed_row = _process_single_row(
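For the concatenate path wired up above, a similar illustrative sketch (again not part of the diff, using the same hypothetical word-counting tokenizer; the "prompt" column name and row contents are made up): the handler keeps pulling rows from the iterator until the token minimum is reached, or returns None if the iterator runs out first.

# Illustrative sketch only, not part of this commit.
class ToyTokenizer:
    # Hypothetical stand-in: one "token" per whitespace-separated word.
    def encode(self, text: str) -> list[int]:
        return [0] * len(text.split())

rows = iter([{"prompt": "more words here"}, {"prompt": "and a few extra"}])
handler = ShortPromptStrategyHandler.get_strategy_handler(ShortPromptStrategy.CONCATENATE)
merged = handler(
    current_prompt="short",
    min_prompt_tokens=6,
    dataset_iterator=rows,
    prompt_column="prompt",
    tokenizer=ToyTokenizer(),
    concat_delimiter=" ",
)
assert merged == "short more words here and a few extra"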