@@ -34,26 +34,47 @@ class ShortPromptStrategy(str, Enum):
3434
3535
def handle_ignore_strategy(
    current_prompt: str,
    min_prompt_tokens: int,
    tokenizer: PreTrainedTokenizerBase,
    **_kwargs,
) -> Optional[str]:
    """
    Drop prompts that tokenize to fewer than the required minimum tokens.

    :param current_prompt: The input prompt string.
    :param min_prompt_tokens: Minimum required token count.
    :param tokenizer: Tokenizer used to count tokens.
    :return: The prompt if it meets the length, otherwise None.
    """
    token_count = len(tokenizer.encode(current_prompt))
    if token_count >= min_prompt_tokens:
        return current_prompt
    # Too short: log and signal the caller to skip this sample.
    logger.warning("Prompt too short, ignoring")
    return None
4655
4756
4857def handle_concatenate_strategy (
49- current_prompt : str ,
50- min_prompt_tokens : int ,
51- dataset_iterator : Iterator [dict [str , Any ]],
52- prompt_column : str ,
53- tokenizer : PreTrainedTokenizerBase ,
54- concat_delimiter : str ,
55- ** _kwargs ,
58+ current_prompt : str ,
59+ min_prompt_tokens : int ,
60+ dataset_iterator : Iterator [dict [str , Any ]],
61+ prompt_column : str ,
62+ tokenizer : PreTrainedTokenizerBase ,
63+ concat_delimiter : str ,
64+ ** _kwargs ,
5665) -> Optional [str ]:
66+ """
67+ Concatenates prompts until the minimum token requirement is met.
68+
69+ :param current_prompt: The initial prompt.
70+ :param min_prompt_tokens: Target minimum token length.
71+ :param dataset_iterator: Iterator to fetch more prompts.
72+ :param prompt_column: Column key for prompt extraction.
73+ :param tokenizer: Tokenizer used to count tokens.
74+ :param concat_delimiter: Delimiter to use between prompts.
75+ :return: Concatenated prompt or None if not enough data.
76+ """
77+
5778 tokens_len = len (tokenizer .encode (current_prompt ))
5879 while tokens_len < min_prompt_tokens :
5980 try :
@@ -69,23 +90,43 @@ def handle_concatenate_strategy(
6990
7091
def handle_pad_strategy(
    current_prompt: str,
    min_prompt_tokens: int,
    tokenizer: PreTrainedTokenizerBase,
    pad_char: str,
    **_kwargs,
) -> str:
    """
    Append a padding character until the prompt reaches the minimum token length.

    :param current_prompt: The input prompt.
    :param min_prompt_tokens: Desired minimum token count.
    :param tokenizer: Tokenizer used to count tokens.
    :param pad_char: Character used for padding.
    :return: Padded prompt string.
    """
    padded = current_prompt
    # Re-encode after each append: a pad_char is not guaranteed to map to
    # exactly one token, so the count must be re-checked each round.
    token_count = len(tokenizer.encode(padded))
    while token_count < min_prompt_tokens:
        padded = padded + pad_char
        token_count = len(tokenizer.encode(padded))
    return padded
81112
82113
83114def handle_error_strategy (
84- current_prompt : str ,
85- min_prompt_tokens : int ,
86- tokenizer : PreTrainedTokenizerBase ,
87- ** _kwargs ,
115+ current_prompt : str ,
116+ min_prompt_tokens : int ,
117+ tokenizer : PreTrainedTokenizerBase ,
118+ ** _kwargs ,
88119) -> Optional [str ]:
120+ """
121+ Raises an error if the prompt is too short.
122+
123+ :param current_prompt: The input prompt.
124+ :param min_prompt_tokens: Required token count.
125+ :param tokenizer: Tokenizer used to count tokens.
126+ :return: The input prompt if valid.
127+ :raises PromptTooShortError: If the prompt is too short.
128+ """
129+
89130 prompt_len = len (tokenizer .encode (current_prompt ))
90131 if prompt_len < min_prompt_tokens :
91132 raise PromptTooShortError (
@@ -126,6 +167,17 @@ class TokensConfig(BaseModel):
126167
127168 @staticmethod
128169 def parse_str (data : Union [str , Path ]) -> "TokensConfig" :
170+ """
171+ Parses a string or path into a TokensConfig object. Supports:
172+ - JSON string
173+ - key=value pairs
174+ - file path to .yaml/.config
175+
176+ :param data: String or path containing configuration.
177+ :return: Parsed TokensConfig instance.
178+ :raises ValueError: If the format is not recognized.
179+ """
180+
129181 if (
130182 isinstance (data , Path )
131183 or data .strip ().endswith (".config" )
@@ -169,6 +221,13 @@ def parse_config_file(data: Union[str, Path]) -> "TokensConfig":
169221 return TokensConfig (** config_dict )
170222
171223def save_dataset_to_file (dataset : Dataset , output_path : Union [str , Path ]) -> None :
224+ """
225+ Saves a HuggingFace Dataset to file in a supported format.
226+
227+ :param dataset: Dataset to save.
228+ :param output_path: Output file path (.json, .jsonl, .csv, .parquet).
229+ :raises ValueError: If the file extension is not supported.
230+ """
172231 output_path = Path (output_path )
173232 output_path .parent .mkdir (parents = True , exist_ok = True )
174233 suffix = output_path .suffix .lower ()
@@ -197,20 +256,39 @@ def _validate_output_suffix(output_path: Union[str, Path]) -> None:
197256
198257
199258def process_dataset (
200- data : Union [str , Path ],
201- output_path : Union [str , Path ],
202- processor : Union [str , Path , PreTrainedTokenizerBase ],
203- prompt_tokens : Union [str , Path ],
204- output_tokens : Union [str , Path ],
205- processor_args : Optional [dict [str , Any ]] = None ,
206- data_args : Optional [dict [str , Any ]] = None ,
207- short_prompt_strategy : ShortPromptStrategy = ShortPromptStrategy .IGNORE ,
208- pad_char : Optional [str ] = None ,
209- concat_delimiter : Optional [str ] = None ,
210- push_to_hub : bool = False ,
211- hub_dataset_id : Optional [str ] = None ,
212- random_seed : int = 42 ,
259+ data : Union [str , Path ],
260+ output_path : Union [str , Path ],
261+ processor : Union [str , Path , PreTrainedTokenizerBase ],
262+ prompt_tokens : Union [str , Path ],
263+ output_tokens : Union [str , Path ],
264+ processor_args : Optional [dict [str , Any ]] = None ,
265+ data_args : Optional [dict [str , Any ]] = None ,
266+ short_prompt_strategy : ShortPromptStrategy = ShortPromptStrategy .IGNORE ,
267+ pad_char : Optional [str ] = None ,
268+ concat_delimiter : Optional [str ] = None ,
269+ push_to_hub : bool = False ,
270+ hub_dataset_id : Optional [str ] = None ,
271+ random_seed : int = 42 ,
213272) -> None :
273+ """
274+ Main method to process and save a dataset with sampled prompt/output token counts.
275+
276+ :param data: Path or identifier for dataset input.
277+ :param output_path: File path to save the processed dataset.
278+ :param processor: Tokenizer object or its config.
279+ :param prompt_tokens: Prompt token config string or file.
280+ :param output_tokens: Output token config string or file.
281+ :param processor_args: Optional processor arguments.
282+ :param data_args: Optional data loading arguments.
283+ :param short_prompt_strategy: Strategy for handling short prompts.
284+ :param pad_char: Character used when padding short prompts.
285+ :param concat_delimiter: Delimiter for concatenation strategy.
286+ :param push_to_hub: Whether to push to Hugging Face Hub.
287+ :param hub_dataset_id: Dataset ID on Hugging Face Hub.
288+ :param random_seed: Seed for random sampling.
289+ :raises ValueError: If output path is invalid or pushing conditions unmet.
290+ """
291+
214292 _validate_output_suffix (output_path )
215293 logger .info (
216294 f"Starting dataset conversion | Input: { data } | "
@@ -300,8 +378,16 @@ def process_dataset(
300378
301379
302380def push_dataset_to_hub (
303- hub_dataset_id : Optional [str ], processed_dataset : Dataset ,
381+ hub_dataset_id : Optional [str ], processed_dataset : Dataset ,
304382) -> None :
383+ """
384+ Pushes the processed dataset to Hugging Face Hub using HF_TOKEN.
385+
386+ :param hub_dataset_id: Identifier on the Hub to push to.
387+ :param processed_dataset: HuggingFace Dataset object.
388+ :raises ValueError: If hub_dataset_id or HF_TOKEN is not available.
389+ """
390+
305391 hf_token = os .environ .get ("HF_TOKEN" )
306392 if not hub_dataset_id or not hf_token :
307393 raise ValueError (
0 commit comments