
Commit be8001b

Merge remote-tracking branch 'upstream/main'
2 parents 68c2469 + 66acb4f commit be8001b

10 files changed: +543 / -20 lines changed

docs/advanced-data-preprocessing.md

Lines changed: 40 additions & 0 deletions
@@ -278,3 +278,43 @@ If the dataset size is known to the user, `max_steps` can be calculated as the t
### Example data configs.

We provide some example data configs [here](../tests/artifacts/predefined_data_configs/)

## Offline Data preprocessing

[This script](../scripts/offline_data_processing.py) provides the capability for users to perform standalone data preprocessing, decoupled from the tuning/training part. It processes raw datasets, performs data preprocessing, and saves the train and validation datasets (in shards, if `--num_dataset_shards` is passed) in parquet format inside the specified `output_dir`.
A data config YAML file can be used to pass configuration to this script. Example command to run this script:

```
python scripts/offline_data_processing.py \
    --data_config_path /path/to/data_config.yaml \
    --model_name_or_path "model_name" \
    --max_seq_length 4096 \
    --output_dir /path/to/output/directory \
    --log_level info \
    --num_dataset_shards 3
```

Example data config file:

```
dataprocessor:
  type: default
  sampling_stopping_strategy: first_exhausted
  seed: 66
datasets:
  - name: dataset_1
    data_paths:
      - tests/artifacts/testdata/jsonl/twitter_complaints_input_output.jsonl
    data_handlers:
      - name: tokenize_and_apply_input_masking
        arguments:
          remove_columns: all
          batched: false
          fn_kwargs:
            input_field_name: input
            output_field_name: output
```
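Not part of this commit, but for context: the script writes its shards under `train_dataset/` and `validation_dataset/` inside `output_dir`, so they can be read back with the Hugging Face `datasets` library. A minimal sketch, assuming the hypothetical `output_dir` from the example command above:

```
# Sketch only (not in this commit): reload the parquet shards written by
# scripts/offline_data_processing.py. The sub-directory names come from main();
# the output path is the placeholder --output_dir used in the example above.
from datasets import load_dataset

output_dir = "/path/to/output/directory"
reloaded = load_dataset(
    "parquet",
    data_files={
        "train": f"{output_dir}/train_dataset/*.parquet",
        "validation": f"{output_dir}/validation_dataset/*.parquet",
    },
)
print(reloaded["train"].num_rows, reloaded["validation"].num_rows)
```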

scripts/offline_data_processing.py

Lines changed: 264 additions & 0 deletions
@@ -0,0 +1,264 @@
# Standard
import logging
import os
import sys
import traceback

# Third Party
from transformers import (
    AutoTokenizer,
    GPT2Tokenizer,
    GPTNeoXTokenizerFast,
    LlamaTokenizer,
    LlamaTokenizerFast,
)

# Local
from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs
from tuning.sft_trainer import get_parser
from tuning.utils.error_logging import USER_ERROR_EXIT_CODE, write_termination_log
from tuning.utils.logging import set_log_level


def save_dataset_shards(
    dataset, output_dir: str, num_shards: int, dataset_name: str
) -> None:
    """
    Saves the given dataset in the specified number of shards.

    Args:
        dataset: The dataset to shard and save.
        output_dir (str): Directory to save the dataset shards.
        num_shards (int): Number of shards to create.
        dataset_name (str): Name of the dataset (used for logging).
    """
    os.makedirs(output_dir, exist_ok=True)
    for shard_idx in range(num_shards):
        shard = dataset.shard(index=shard_idx, num_shards=num_shards)
        shard_path = os.path.join(output_dir, f"ds_{shard_idx:05d}.parquet")
        shard.to_parquet(shard_path)
    logging.info("Dumped %d shards of %s at %s", num_shards, dataset_name, output_dir)


def get_processed_dataset(
    model_args: configs.ModelArguments,
    data_args: configs.DataArguments,
    train_args: configs.TrainingArguments,
):
    """
    Processes the dataset based on data config yaml.

    Args:
        model_args (configs.ModelArguments): Model configuration arguments.
        data_args (configs.DataArguments): Data configuration arguments.
        train_args (configs.TrainingArguments): Training configuration arguments.

    Returns:
        tuple: A tuple containing the formatted training dataset and validation dataset.
    """
    # Set log level for this function
    train_args, logger = set_log_level(train_args, "get_processed_dataset")

    logger.info(
        "Starting dataset processing with model_args: %s, data_args: %s, training_args: %s",
        model_args,
        data_args,
        train_args,
    )

    # Load tokenizer for the model
    tokenizer_path = model_args.tokenizer_name_or_path or model_args.model_name_or_path
    logger.debug("Loading tokenizer from %s", tokenizer_path)
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        cache_dir=train_args.cache_dir,
        use_fast=True,
        legacy=True,
    )
    logger.debug("Tokenizer loaded successfully.")

    # Add chat_template to the tokenizer if provided
    if data_args.chat_template:
        data_args.chat_template = data_args.chat_template.replace(r"\n", "\n")

        logger.info("Adding chat_template to the tokenizer")
        if tokenizer.chat_template:
            logger.warning(
                "replacing existing chat_template %s with the given chat_template %s",
                tokenizer.chat_template,
                data_args.chat_template,
            )
        tokenizer.chat_template = data_args.chat_template

    # Prepare special tokens dictionary
    special_tokens_dict = {}
    if not model_args.tokenizer_name_or_path:
        if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
            special_tokens_dict["bos_token"] = "<s>"
            special_tokens_dict["eos_token"] = "</s>"
            special_tokens_dict["unk_token"] = "<unk>"
            special_tokens_dict["pad_token"] = "<pad>"
        elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
            special_tokens_dict["pad_token"] = "<pad>"

    if tokenizer.pad_token is None:
        logger.warning(
            "PAD token not found in tokenizer; setting PAD token to default."
        )
        special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
    if tokenizer.eos_token is None:
        logger.warning(
            "EOS token not found in tokenizer; setting EOS token to default."
        )
        special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
    if tokenizer.pad_token == tokenizer.eos_token:
        logger.warning(
            "PAD token and EOS token are the same. Overriding accordingly."
        )
        if tokenizer.eos_token != configs.DEFAULT_PAD_TOKEN:
            tokenizer.pad_token = configs.DEFAULT_PAD_TOKEN
            special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
        else:
            tokenizer.eos_token = configs.DEFAULT_EOS_TOKEN
            special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN

    # adds user specified special tokens to vocab
    if data_args.add_special_tokens:
        logger.info(
            "Adding user-defined special tokens: %s ", data_args.add_special_tokens
        )
        special_tokens_dict["additional_special_tokens"] = data_args.add_special_tokens

    if special_tokens_dict:
        logger.info("Adding special tokens: %s", special_tokens_dict)
        tokenizer.add_special_tokens(special_tokens_dict)

    # Process data using the provided arguments and tokenizer
    logger.info("Calling process_dataargs to format datasets.")
    (
        formatted_train_dataset,
        formatted_validation_dataset,
        _,
        _,
        _,
        _,
    ) = process_dataargs(data_args, tokenizer, train_args)
    logger.info("Dataset processing completed successfully.")

    return formatted_train_dataset, formatted_validation_dataset


def main():
    """
    Main function that parses arguments, processes datasets, and saves the output.
    """
    logger = logging.getLogger()
    logger.info("Starting Data Processing script execution.")

    parser = get_parser()
    parser.add_argument(
        "--num_dataset_shards",
        type=int,
        default=1,
        help="Number of shards to be used for saving the dataset.",
    )

    try:
        parsed_output = parser.parse_args_into_dataclasses()
        # Extract arguments based on type
        arg_types = {
            configs.ModelArguments: "model_args",
            configs.DataArguments: "data_args",
            configs.TrainingArguments: "training_args",
        }
        args = {key: None for key in arg_types.values()}
        for item in parsed_output:
            for arg_class, key in arg_types.items():
                if isinstance(item, arg_class):
                    args[key] = item

        # Extract additional namespace argument
        num_dataset_shards = next(
            (
                item.num_dataset_shards
                for item in parsed_output
                if hasattr(item, "num_dataset_shards")
            ),
            1,
        )

        if None in args.values():
            raise ValueError(
                "One of the arguments is None. Please check the arguments passed."
            )

        logger.debug(
            "Input args parsed:\n model_args: %s\n data_args: %s\n training_args: %s\n Shards: %d",
            args["model_args"],
            args["data_args"],
            args["training_args"],
            num_dataset_shards,
        )
        args["training_args"], logger = set_log_level(args["training_args"], __name__)
    except Exception as e:  # pylint: disable=broad-exception-caught
        logger.error("Error parsing arguments: %s", traceback.format_exc())
        write_termination_log(f"Exception raised during argument parsing: {e}")
        sys.exit(USER_ERROR_EXIT_CODE)

    try:
        logger.info("Processing dataset.")
        formatted_train_dataset, formatted_validation_dataset = get_processed_dataset(
            model_args=args["model_args"],
            data_args=args["data_args"],
            train_args=args["training_args"],
        )
    except Exception as e:  # pylint: disable=broad-exception-caught
        logger.error("Error processing dataset: %s", traceback.format_exc())
        write_termination_log(f"Exception raised during dataset processing: {e}")
        sys.exit(USER_ERROR_EXIT_CODE)

    # Save train dataset shards
    train_dataset_dir = os.path.join(args["training_args"].output_dir, "train_dataset")
    logging.info(
        "Trying to dump %d shards of train dataset at %s",
        num_dataset_shards,
        train_dataset_dir,
    )
    if formatted_train_dataset is not None:
        save_dataset_shards(
            formatted_train_dataset,
            train_dataset_dir,
            num_dataset_shards,
            "train_dataset",
        )
    else:
        logging.warning("Train dataset is None. Not saving train dataset.")

    # Save validation dataset shards
    validation_dataset_dir = os.path.join(
        args["training_args"].output_dir, "validation_dataset"
    )
    logging.info(
        "Trying to dump %d shards of validation dataset at %s",
        num_dataset_shards,
        validation_dataset_dir,
    )
    if formatted_validation_dataset is not None:
        save_dataset_shards(
            formatted_validation_dataset,
            validation_dataset_dir,
            num_dataset_shards,
            "validation_dataset",
        )
    else:
        logging.warning("Validation dataset is None. Not saving validation dataset.")

    logger.info(
        "Data Processing script execution completed. Data saved in %s directory",
        args["training_args"].output_dir,
    )


if __name__ == "__main__":
    main()
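As a side note on the sharding pattern used by `save_dataset_shards` above: `datasets.Dataset.shard(index=i, num_shards=n)` returns the i-th of n roughly equal slices, so writing each slice to its own parquet file covers every row exactly once. A small self-contained sketch on toy data (not part of this commit):

```
# Toy illustration (not part of this commit) of the Dataset.shard() pattern
# used in save_dataset_shards(): n shards together partition the dataset.
from datasets import Dataset

ds = Dataset.from_dict({"text": [f"row {i}" for i in range(10)]})
num_shards = 3
total = 0
for shard_idx in range(num_shards):
    shard = ds.shard(index=shard_idx, num_shards=num_shards)
    total += len(shard)
    print(f"shard {shard_idx}: {len(shard)} rows")
assert total == len(ds)  # every row lands in exactly one shard
```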

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -37,6 +37,9 @@
DATA_CONFIG_MULTITURN_DATA_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "multi_turn_data_with_chat_template.yaml"
)
DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML = os.path.join(
    PREDEFINED_DATA_CONFIGS, "multi_turn_data_with_chat_template_granite_3_1B.yaml"
)
DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT = os.path.join(
    PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml"
)

tests/artifacts/predefined_data_configs/multi_turn_data_with_chat_template.yaml

Lines changed: 20 additions & 1 deletion
@@ -16,5 +16,24 @@ datasets:
    data_handlers:
      - name: apply_tokenizer_chat_template
        arguments:
          remove_columns: all
          fn_kwargs:
            dataset_text_field: "formatted_chat_data"
  - name: dataset_2
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_tokenizer_chat_template
        arguments:
          remove_columns: all
          fn_kwargs:
            dataset_text_field: "formatted_chat_data"
  - name: dataset_3
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: apply_tokenizer_chat_template
        arguments:
          remove_columns: all
          fn_kwargs:
            dataset_text_field: "formatted_chat_data"
