Skip to content

Commit f61b7f0

Browse files
committed
Added docs
1 parent 40c1118 commit f61b7f0

File tree

1 file changed

+120
-34
lines changed

1 file changed

+120
-34
lines changed

src/guidellm/preprocess/dataset.py

Lines changed: 120 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -34,26 +34,47 @@ class ShortPromptStrategy(str, Enum):
3434

3535

3636
def handle_ignore_strategy(
37-
current_prompt: str,
38-
min_prompt_tokens: int,
39-
tokenizer: PreTrainedTokenizerBase,
40-
**_kwargs,
37+
current_prompt: str,
38+
min_prompt_tokens: int,
39+
tokenizer: PreTrainedTokenizerBase,
40+
**_kwargs,
4141
) -> Optional[str]:
42+
"""
43+
Ignores prompts that are shorter than the required minimum token length.
44+
45+
:param current_prompt: The input prompt string.
46+
:param min_prompt_tokens: Minimum required token count.
47+
:param tokenizer: Tokenizer used to count tokens.
48+
:return: The prompt if it meets the length, otherwise None.
49+
"""
50+
4251
if len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
4352
logger.warning("Prompt too short, ignoring")
4453
return None
4554
return current_prompt
4655

4756

4857
def handle_concatenate_strategy(
49-
current_prompt: str,
50-
min_prompt_tokens: int,
51-
dataset_iterator: Iterator[dict[str, Any]],
52-
prompt_column: str,
53-
tokenizer: PreTrainedTokenizerBase,
54-
concat_delimiter: str,
55-
**_kwargs,
58+
current_prompt: str,
59+
min_prompt_tokens: int,
60+
dataset_iterator: Iterator[dict[str, Any]],
61+
prompt_column: str,
62+
tokenizer: PreTrainedTokenizerBase,
63+
concat_delimiter: str,
64+
**_kwargs,
5665
) -> Optional[str]:
66+
"""
67+
Concatenates prompts until the minimum token requirement is met.
68+
69+
:param current_prompt: The initial prompt.
70+
:param min_prompt_tokens: Target minimum token length.
71+
:param dataset_iterator: Iterator to fetch more prompts.
72+
:param prompt_column: Column key for prompt extraction.
73+
:param tokenizer: Tokenizer used to count tokens.
74+
:param concat_delimiter: Delimiter to use between prompts.
75+
:return: Concatenated prompt or None if not enough data.
76+
"""
77+
5778
tokens_len = len(tokenizer.encode(current_prompt))
5879
while tokens_len < min_prompt_tokens:
5980
try:
@@ -69,23 +90,43 @@ def handle_concatenate_strategy(
6990

7091

7192
def handle_pad_strategy(
72-
current_prompt: str,
73-
min_prompt_tokens: int,
74-
tokenizer: PreTrainedTokenizerBase,
75-
pad_char: str,
76-
**_kwargs,
93+
current_prompt: str,
94+
min_prompt_tokens: int,
95+
tokenizer: PreTrainedTokenizerBase,
96+
pad_char: str,
97+
**_kwargs,
7798
) -> str:
99+
"""
100+
Pads the prompt with a character until it reaches the minimum token length.
101+
102+
:param current_prompt: The input prompt.
103+
:param min_prompt_tokens: Desired minimum token count.
104+
:param tokenizer: Tokenizer used to count tokens.
105+
:param pad_char: Character used for padding.
106+
:return: Padded prompt string.
107+
"""
108+
78109
while len(tokenizer.encode(current_prompt)) < min_prompt_tokens:
79110
current_prompt += pad_char
80111
return current_prompt
81112

82113

83114
def handle_error_strategy(
84-
current_prompt: str,
85-
min_prompt_tokens: int,
86-
tokenizer: PreTrainedTokenizerBase,
87-
**_kwargs,
115+
current_prompt: str,
116+
min_prompt_tokens: int,
117+
tokenizer: PreTrainedTokenizerBase,
118+
**_kwargs,
88119
) -> Optional[str]:
120+
"""
121+
Raises an error if the prompt is too short.
122+
123+
:param current_prompt: The input prompt.
124+
:param min_prompt_tokens: Required token count.
125+
:param tokenizer: Tokenizer used to count tokens.
126+
:return: The input prompt if valid.
127+
:raises PromptTooShortError: If the prompt is too short.
128+
"""
129+
89130
prompt_len = len(tokenizer.encode(current_prompt))
90131
if prompt_len < min_prompt_tokens:
91132
raise PromptTooShortError(
@@ -126,6 +167,17 @@ class TokensConfig(BaseModel):
126167

127168
@staticmethod
128169
def parse_str(data: Union[str, Path]) -> "TokensConfig":
170+
"""
171+
Parses a string or path into a TokensConfig object. Supports:
172+
- JSON string
173+
- key=value pairs
174+
- file path to .yaml/.config
175+
176+
:param data: String or path containing configuration.
177+
:return: Parsed TokensConfig instance.
178+
:raises ValueError: If the format is not recognized.
179+
"""
180+
129181
if (
130182
isinstance(data, Path)
131183
or data.strip().endswith(".config")
@@ -169,6 +221,13 @@ def parse_config_file(data: Union[str, Path]) -> "TokensConfig":
169221
return TokensConfig(**config_dict)
170222

171223
def save_dataset_to_file(dataset: Dataset, output_path: Union[str, Path]) -> None:
224+
"""
225+
Saves a HuggingFace Dataset to file in a supported format.
226+
227+
:param dataset: Dataset to save.
228+
:param output_path: Output file path (.json, .jsonl, .csv, .parquet).
229+
:raises ValueError: If the file extension is not supported.
230+
"""
172231
output_path = Path(output_path)
173232
output_path.parent.mkdir(parents=True, exist_ok=True)
174233
suffix = output_path.suffix.lower()
@@ -197,20 +256,39 @@ def _validate_output_suffix(output_path: Union[str, Path]) -> None:
197256

198257

199258
def process_dataset(
200-
data: Union[str, Path],
201-
output_path: Union[str, Path],
202-
processor: Union[str, Path, PreTrainedTokenizerBase],
203-
prompt_tokens: Union[str, Path],
204-
output_tokens: Union[str, Path],
205-
processor_args: Optional[dict[str, Any]] = None,
206-
data_args: Optional[dict[str, Any]] = None,
207-
short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE,
208-
pad_char: Optional[str] = None,
209-
concat_delimiter: Optional[str] = None,
210-
push_to_hub: bool = False,
211-
hub_dataset_id: Optional[str] = None,
212-
random_seed: int = 42,
259+
data: Union[str, Path],
260+
output_path: Union[str, Path],
261+
processor: Union[str, Path, PreTrainedTokenizerBase],
262+
prompt_tokens: Union[str, Path],
263+
output_tokens: Union[str, Path],
264+
processor_args: Optional[dict[str, Any]] = None,
265+
data_args: Optional[dict[str, Any]] = None,
266+
short_prompt_strategy: ShortPromptStrategy = ShortPromptStrategy.IGNORE,
267+
pad_char: Optional[str] = None,
268+
concat_delimiter: Optional[str] = None,
269+
push_to_hub: bool = False,
270+
hub_dataset_id: Optional[str] = None,
271+
random_seed: int = 42,
213272
) -> None:
273+
"""
274+
Main method to process and save a dataset with sampled prompt/output token counts.
275+
276+
:param data: Path or identifier for dataset input.
277+
:param output_path: File path to save the processed dataset.
278+
:param processor: Tokenizer instance, or a string or path identifying the tokenizer to use.
279+
:param prompt_tokens: Prompt token config string or file.
280+
:param output_tokens: Output token config string or file.
281+
:param processor_args: Optional processor arguments.
282+
:param data_args: Optional data loading arguments.
283+
:param short_prompt_strategy: Strategy for handling short prompts.
284+
:param pad_char: Character used when padding short prompts.
285+
:param concat_delimiter: Delimiter for concatenation strategy.
286+
:param push_to_hub: Whether to push to Hugging Face Hub.
287+
:param hub_dataset_id: Dataset ID on Hugging Face Hub.
288+
:param random_seed: Seed for random sampling.
289+
:raises ValueError: If the output path is invalid or the conditions for pushing to the Hub are not met.
290+
"""
291+
214292
_validate_output_suffix(output_path)
215293
logger.info(
216294
f"Starting dataset conversion | Input: {data} | "
@@ -300,8 +378,16 @@ def process_dataset(
300378

301379

302380
def push_dataset_to_hub(
303-
hub_dataset_id: Optional[str], processed_dataset: Dataset,
381+
hub_dataset_id: Optional[str], processed_dataset: Dataset,
304382
) -> None:
383+
"""
384+
Pushes the processed dataset to the Hugging Face Hub, authenticating with the HF_TOKEN environment variable.
385+
386+
:param hub_dataset_id: Identifier on the Hub to push to.
387+
:param processed_dataset: HuggingFace Dataset object.
388+
:raises ValueError: If hub_dataset_id is missing or empty, or the HF_TOKEN environment variable is unset or empty.
389+
"""
390+
305391
hf_token = os.environ.get("HF_TOKEN")
306392
if not hub_dataset_id or not hf_token:
307393
raise ValueError(

0 commit comments

Comments
 (0)