11import asyncio
2+ import codecs
23import json
34from pathlib import Path
45from typing import get_args
@@ -280,6 +281,18 @@ def benchmark(
280281 )
281282 )
282283
def decode_escaped_str(_ctx, _param, value):
    """
    Click option callback that interprets backslash escapes in a CLI value.

    The shell hands arguments over verbatim, so ``--pad-char "\\n"`` reaches
    Click as the two characters backslash + ``n`` rather than a newline. This
    callback converts such escape sequences into the characters they denote.

    :param _ctx: Click context (unused; required by the callback signature).
    :param _param: Click parameter (unused; required by the callback signature).
    :param value: Raw option value, or None when no value was supplied.
    :return: The decoded string, or None if ``value`` is None.
    :raises click.BadParameter: If ``value`` contains a malformed escape
        sequence (for example a trailing backslash).
    """
    if value is None:
        return None
    try:
        # "unicode_escape" is a bytes codec that treats its input as Latin-1.
        # Handing it a str makes Python encode that str as UTF-8 first, which
        # corrupts every non-ASCII character. Encode to Latin-1 explicitly and
        # let "backslashreplace" turn anything outside Latin-1 into a Unicode
        # escape that the decoder then restores unchanged.
        raw = value.encode("latin-1", "backslashreplace")
        return codecs.decode(raw, "unicode_escape")
    except UnicodeDecodeError as e:
        raise click.BadParameter(f"Could not decode escape sequences: {e}") from e
283296
284297@cli .command (
285298 help = (
@@ -291,27 +304,26 @@ def config():
291304 print_config ()
292305
293306
@cli.group(
    help="General preprocessing tools and utilities.",
)
def preprocess():
    """Click group that collects the preprocessing subcommands."""
297310
298311
299312@preprocess .command (
300- help = "Convert a dataset to have specific prompt and output token sizes.\n \n "
301- "INPUT_DATA: Path to the input dataset or dataset ID.\n "
302- "OUTPUT_PATH: Directory to save the converted dataset. "
303- "The dataset will be saved as an Arrow dataset (.arrow) inside the directory."
313+ help = (
314+ "Convert a dataset to have specific prompt and output token sizes.\n \n "
315+ "INPUT_DATA: Path to the input dataset or dataset ID.\n "
316+ "OUTPUT_PATH: Path to save the converted dataset, including file suffix. "
317+ )
304318)
305319@click .argument (
306- "input_data " ,
320+ "data " ,
307321 type = str ,
308- metavar = "INPUT_DATA" ,
309322 required = True ,
310323)
311324@click .argument (
312325 "output_path" ,
313326 type = click .Path (file_okay = True , dir_okay = False , writable = True , resolve_path = True ),
314- metavar = "OUTPUT_PATH" ,
315327 required = True ,
316328)
317329@click .option (
@@ -348,11 +360,21 @@ def preprocess():
348360 help = "Strategy to handle prompts shorter than the target length. " ,
349361)
350362@click .option (
351- "--pad-token " ,
363+ "--pad-char " ,
352364 type = str ,
353- default = None ,
365+ default = "" ,
366+ callback = decode_escaped_str ,
354367 help = "The token to pad short prompts with when using the 'pad' strategy." ,
355368)
369+ @click .option (
370+ "--concat-delimiter" ,
371+ type = str ,
372+ default = "" ,
373+ help = (
374+ "The delimiter to use when concatenating prompts that are too short."
375+ " Used when strategy is 'concatenate'."
376+ )
377+ )
356378@click .option (
357379 "--prompt-tokens-average" ,
358380 type = int ,
@@ -378,13 +400,6 @@ def preprocess():
378400 default = None ,
379401 help = "Maximum number of prompt tokens." ,
380402)
381- @click .option (
382- "--prompt-random-seed" ,
383- type = int ,
384- default = 42 ,
385- show_default = True ,
386- help = "Random seed for prompt token sampling." ,
387- )
388403@click .option (
389404 "--output-tokens-average" ,
390405 type = int ,
@@ -410,13 +425,6 @@ def preprocess():
410425 default = None ,
411426 help = "Maximum number of output tokens." ,
412427)
413- @click .option (
414- "--output-random-seed" ,
415- type = int ,
416- default = 123 ,
417- show_default = True ,
418- help = "Random seed for output token sampling." ,
419- )
420428@click .option (
421429 "--push-to-hub" ,
422430 is_flag = True ,
@@ -429,47 +437,54 @@ def preprocess():
429437 help = "The Hugging Face Hub dataset ID to push to. "
430438 "Required if --push-to-hub is used." ,
431439)
440+ @click .option (
441+ "--random-seed" ,
442+ type = int ,
443+ default = 42 ,
444+ show_default = True ,
445+ help = "Random seed for prompt token sampling and output tokens sampling." ,
446+ )
432447def dataset (
433- input_data ,
448+ data ,
434449 output_path ,
435450 processor ,
436451 processor_args ,
437452 data_args ,
438453 short_prompt_strategy ,
439- pad_token ,
454+ pad_char ,
455+ concat_delimiter ,
440456 prompt_tokens_average ,
441457 prompt_tokens_stdev ,
442458 prompt_tokens_min ,
443459 prompt_tokens_max ,
444- prompt_random_seed ,
445460 output_tokens_average ,
446461 output_tokens_stdev ,
447462 output_tokens_min ,
448463 output_tokens_max ,
449- output_random_seed ,
450464 push_to_hub ,
451465 hub_dataset_id ,
466+ random_seed ,
452467):
453468 process_dataset (
454- input_data = input_data ,
469+ data = data ,
455470 output_path = output_path ,
456471 processor = processor ,
457472 processor_args = processor_args ,
458473 data_args = data_args ,
459474 short_prompt_strategy = short_prompt_strategy ,
460- pad_token = pad_token ,
475+ pad_char = pad_char ,
476+ concat_delimiter = concat_delimiter ,
461477 prompt_tokens_average = prompt_tokens_average ,
462478 prompt_tokens_stdev = prompt_tokens_stdev ,
463479 prompt_tokens_min = prompt_tokens_min ,
464480 prompt_tokens_max = prompt_tokens_max ,
465- prompt_random_seed = prompt_random_seed ,
466481 output_tokens_average = output_tokens_average ,
467482 output_tokens_stdev = output_tokens_stdev ,
468483 output_tokens_min = output_tokens_min ,
469484 output_tokens_max = output_tokens_max ,
470- output_random_seed = output_random_seed ,
471485 push_to_hub = push_to_hub ,
472486 hub_dataset_id = hub_dataset_id ,
487+ random_seed = random_seed ,
473488 )
474489
475490
0 commit comments