
Commit b82549e

Merge remote-tracking branch 'upstream/main'
2 parents 6757c0b + 14f2f24 commit b82549e

File tree: 11 files changed, +437 -83 lines


docs/advanced-data-preprocessing.md

Lines changed: 2 additions & 0 deletions
@@ -233,6 +233,8 @@ This library currently supports the following [preexisting data handlers](https:
   Uses a tokenizer's chat template to preprocess dataset elements, good for single/multi turn chat templates.
 - `duplicate_columns`:
   Duplicates one column of the dataset to another column.
+- `tokenize`:
+  Tokenizes one column of the dataset passed as input `dataset_text_field`.
 
 These handlers could be requested by their same name and users can lookup the function args from [here](https://github.com/foundation-model-stack/fms-hf-tuning/blob/main/tuning/data/data_handlers.py)
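
The new `tokenize` handler documented above is exercised directly in the test changes further down this page; a minimal sketch of the same call pattern, assuming a placeholder JSONL file with an `output` column and a placeholder model name (neither is part of this commit):

```python
# Sketch of applying the `tokenize` handler outside a data config, mirroring
# the call pattern in tests/data/test_data_handlers.py from this commit.
# The data file and model name below are placeholders.
import datasets
from transformers import AutoTokenizer

from tuning.data.data_handlers import tokenize

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")  # placeholder model
ds = datasets.load_dataset("json", data_files="train.jsonl")    # placeholder data

tokenized = ds.map(
    tokenize,
    fn_kwargs={
        "tokenizer": tokenizer,
        "dataset_text_field": "output",  # column to tokenize
        "truncation": True,
        "max_length": 1024,
    },
)
```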

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ classifiers=[
 dependencies = [
 "numpy>=1.26.4,<2.0",
 "accelerate>=0.20.3,!=0.34,<1.1",
-"transformers>=4.46,<4.48.2",
+"transformers>=4.49,<5.0",
 "torch>=2.2.0,<2.5",
 "sentencepiece>=0.1.99,<0.3",
 "tokenizers>=0.13.3,<1.0",

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -49,3 +49,9 @@
 DATA_CONFIG_RENAME_RETAIN_COLUMNS = os.path.join(
     PREDEFINED_DATA_CONFIGS, "rename_retain_columns.yaml"
 )
+DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "tokenize_using_handler_and_train.yaml"
+)
+DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "skip_large_text_data_handler_template.yaml"
+)

tests/artifacts/predefined_data_configs/skip_large_text_data_handler_template.yaml

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+dataprocessor:
+    type: default
+datasets:
+  - name: pre_tokenized
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize
+        arguments:
+          remove_columns: all
+          batched: true
+          fn_kwargs:
+            dataset_text_field: "output"
+      - name: duplicate_columns
+        arguments:
+          remove_columns: all
+          batched: true
+          fn_kwargs:
+            old_column: "input_ids"
+            new_column: "labels"
+      - name: skip_large_text
+        arguments:
+          fn_kwargs:
+            column_name: "input_ids"
+            max_length: 50
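
This template (matching the `skip_large_text_data_handler_template.yaml` constant added to `__init__.py` above, judging by its handler chain) tokenizes the text, copies `input_ids` into `labels`, and then applies the new `skip_large_text` filter. The filter's implementation is not part of this diff; a rough sketch of the predicate it would need, inferred from the arguments above and the unit test further down (which keeps 60 of 100 rows with `max_length: 61`):

```python
# Hypothetical sketch of a skip_large_text-style filter handler; the real
# implementation lives in tuning/data/data_handlers.py and is not shown in
# this diff. Inferred behaviour: validate the arguments, then keep only rows
# whose column is strictly shorter than max_length.
def skip_large_text_sketch(element, column_name=None, max_length=None, **kwargs):
    if not isinstance(column_name, str) or not isinstance(max_length, int):
        raise ValueError("column_name must be a str and max_length an int")
    if column_name not in element:
        raise ValueError(f"column {column_name} not found in dataset element")
    return len(element[column_name]) < max_length
```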

tests/artifacts/predefined_data_configs/tokenize_using_handler_and_train.yaml

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
+dataprocessor:
+    type: default
+datasets:
+  - name: non_tokenized_dataset
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: tokenize
+        arguments:
+          remove_columns: all
+          batched: true
+          fn_kwargs:
+            dataset_text_field: "output"
+            truncation: True
+            max_length: 1024
+      - name: duplicate_columns
+        arguments:
+          remove_columns: all
+          batched: true
+          fn_kwargs:
+            old_column: "input_ids"
+            new_column: "labels"

tests/data/test_data_handlers.py

Lines changed: 63 additions & 1 deletion
@@ -16,7 +16,7 @@
 # https://spdx.dev/learn/handling-license-info/
 
 # Third Party
-from datasets import IterableDatasetDict
+from datasets import Dataset, IterableDatasetDict
 from transformers import AutoTokenizer
 import datasets
 import pytest
@@ -35,7 +35,10 @@
     apply_custom_jinja_template,
     combine_sequence,
     duplicate_columns,
+    skip_large_text,
+    tokenize,
 )
+from tuning.data.setup_dataprocessor import is_pretokenized_dataset
 
 
 def test_apply_custom_formatting_template():
@@ -250,3 +253,62 @@ def test_duplicate_columns_copies_columns():
     assert new in first_element
     assert old in first_element
     assert first_element[new] == first_element[old]
+
+
+def test_tokenizer_data_handler_tokenizes():
+    "Ensure the tokenizer data handler tokenizes the input with proper truncation"
+    d = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA_JSONL)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    dataset_text_field = "output"
+    truncation = True
+    max_length = 10
+
+    updated_dataset = d.map(
+        tokenize,
+        fn_kwargs={
+            "tokenizer": tokenizer,
+            "dataset_text_field": dataset_text_field,
+            "truncation": truncation,
+            "max_length": max_length,
+        },
+    )
+
+    assert "input_ids" in updated_dataset["train"][0]
+    for element in updated_dataset["train"]:
+        assert len(element["input_ids"]) <= max_length
+
+
+@pytest.mark.parametrize(
+    "column_name, max_length",
+    [
+        (None, None),
+        ("input_ids", None),
+        (1024, 1024),
+        ("not_existing", "not_existing"),
+    ],
+)
+def test_skip_large_text_handler_throws_error_on_bad_args(column_name, max_length):
+    "Ensure that the skip large text handler throws an error on bad arguments"
+    d = datasets.load_dataset("json", data_files=TWITTER_COMPLAINTS_DATA_JSONL)
+    fn_kwargs = {}
+    fn_kwargs["column_name"] = column_name
+    fn_kwargs["max_length"] = max_length
+
+    with pytest.raises(ValueError):
+        filtered = d.filter(skip_large_text, fn_kwargs=fn_kwargs)
+
+
+def test_skip_large_text_handler():
+    "Ensure that the skip large text handler skips dataset elements as intended"
+
+    def test_dataset_generator():
+        for i in range(0, 100):
+            yield {"input": list(range(0, i + 1))}
+
+    d = Dataset.from_generator(test_dataset_generator)
+    fn_kwargs = {}
+    fn_kwargs["column_name"] = "input"
+    fn_kwargs["max_length"] = 61
+
+    filtered = d.filter(skip_large_text, fn_kwargs=fn_kwargs)
+    assert len(filtered) == 60

tests/test_sft_trainer.py

Lines changed: 107 additions & 11 deletions
@@ -41,7 +41,9 @@
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
     DATA_CONFIG_MULTITURN_DATA_YAML,
     DATA_CONFIG_RENAME_RETAIN_COLUMNS,
+    DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
+    DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER,
     DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT,
     DATA_CONFIG_YAML_STREAMING_PRETOKENIZED,
 )
@@ -78,7 +80,11 @@
     DataPreProcessorConfig,
     DataSetConfig,
 )
-from tuning.data.data_handlers import add_tokenizer_eos_token
+from tuning.data.data_handlers import (
+    DataHandler,
+    DataHandlerType,
+    add_tokenizer_eos_token,
+)
 
 MODEL_ARGS = configs.ModelArguments(
     model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32"
@@ -321,14 +327,6 @@ def _get_training_logs_by_epoch(dir_path: str, epoch: int = None):
     return data_list
 
 
-def test_run_train_requires_output_dir():
-    """Check fails when output dir not provided."""
-    updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS)
-    updated_output_dir_train_args.output_dir = None
-    with pytest.raises(TypeError):
-        sft_trainer.train(MODEL_ARGS, DATA_ARGS, updated_output_dir_train_args, None)
-
-
 def test_run_train_fails_training_data_path_not_exist():
     """Check fails when data path not found."""
     updated_data_path_args = copy.deepcopy(DATA_ARGS)
@@ -996,6 +994,97 @@ def test_run_training_with_pretokenised_dataset_containing_input_ids():
     assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
 
 
+def test_run_training_with_data_tokenized_using_tokenizer_handler():
+    """Ensure that training on a non tokenized dataset works by tokenizing it with
+    the tokenizer data handler via a data config."""
+    with tempfile.TemporaryDirectory() as tempdir:
+
+        data_args = copy.deepcopy(DATA_ARGS)
+
+        # set training_data_path and response_template to none
+        data_args.response_template = None
+        data_args.training_data_path = None
+
+        dataconfigfile = DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER
+        datapath = TWITTER_COMPLAINTS_DATA_JSONL
+
+        # add data_paths in data_config file
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, suffix=".yaml"
+        ) as temp_yaml_file:
+            with open(dataconfigfile, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+                datasets = data["datasets"]
+                for _, d in enumerate(datasets):
+                    d["data_paths"] = [datapath]
+                yaml.dump(data, temp_yaml_file)
+            data_args.data_config_path = temp_yaml_file.name
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, data_args, train_args)
+
+        # validate full ft configs
+        _validate_training(tempdir)
+        checkpoint_path = _get_checkpoint_path(tempdir)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+        )
+        assert len(output_inference) > 0
+        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+
+
+def test_run_training_with_skip_large_text_handler():
+    """Ensure that we can train successfully after using the skip large text handler."""
+    with tempfile.TemporaryDirectory() as tempdir:
+
+        data_args = copy.deepcopy(DATA_ARGS)
+
+        # set training_data_path and response_template to none
+        data_args.response_template = None
+        data_args.training_data_path = None
+
+        dataconfigfile = DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER
+        datapath = TWITTER_COMPLAINTS_TOKENIZED_JSON
+
+        # add data_paths in data_config file
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, suffix=".yaml"
+        ) as temp_yaml_file:
+            with open(dataconfigfile, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+                datasets = data["datasets"]
+                for _, d in enumerate(datasets):
+                    d["data_paths"] = [datapath]
+                yaml.dump(data, temp_yaml_file)
+            data_args.data_config_path = temp_yaml_file.name
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, data_args, train_args)
+
+        # validate full ft configs
+        _validate_training(tempdir)
+        checkpoint_path = _get_checkpoint_path(tempdir)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50
+        )
+        assert len(output_inference) > 0
+        assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference
+
+
 @pytest.mark.parametrize(
     "dataset_path",
     [CHAT_DATA_SINGLE_TURN, CHAT_DATA_MULTI_TURN],
@@ -1656,7 +1745,8 @@ def test_run_with_bad_additional_data_handlers(additional_handlers):
     train_args.output_dir = tempdir
 
     with pytest.raises(
-        ValueError, match="Handlers should be of type Dict, str to callable"
+        ValueError,
+        match="Handler should be of type tuning.data_handler.DataHandler, and name of str",
     ):
         sft_trainer.train(
             MODEL_ARGS,
@@ -1725,6 +1815,12 @@ def test_handler(element, tokenizer, **kwargs):
             DATA_ARGS,
             train_args,
             PEFT_PT_ARGS,
-            additional_data_handlers={TEST_HANDLER: test_handler},
+            additional_data_handlers={
+                TEST_HANDLER: DataHandler(
+                    op=test_handler,
+                    handler_type=DataHandlerType.MAP,
+                    allows_batching=False,
+                )
+            },
         )
         _validate_training(tempdir)
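
The last two hunks track an API change: additional data handlers are no longer registered as bare callables but wrapped in `DataHandler`. A minimal sketch of registering a custom map-style handler under the new interface, following the updated test above (the handler body is illustrative only):

```python
# Sketch of registering a custom handler with the new DataHandler wrapper,
# mirroring the updated test above. The handler body itself is illustrative.
from tuning.data.data_handlers import DataHandler, DataHandlerType


def append_eos(element, tokenizer, **kwargs):
    # Hypothetical op: append the tokenizer's EOS token to one column.
    element["output"] = element["output"] + tokenizer.eos_token
    return element


additional_data_handlers = {
    "append_eos": DataHandler(
        op=append_eos,
        handler_type=DataHandlerType.MAP,
        allows_batching=False,
    )
}
# passed as: sft_trainer.train(..., additional_data_handlers=additional_data_handlers)
```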
