Commit 5246ed2: Merge remote-tracking branch 'upstream/main'

2 parents: b2cd930 + 8d4ba0b
12 files changed (+486 / -112 lines)


docs/advanced-data-preprocessing.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -9,7 +9,7 @@ These things are supported via what we call a [`data_config`](#data-config) whic
 
 ## Data Config
 
-Data config is a configuration file which `sft_trainer.py` supports as an argument via `--data_config` flag. In this
+Data config is a configuration file which `sft_trainer.py` supports as an argument via `--data_config_path` flag. In this
 configuration users can describe multiple datasets, configurations on how to load the datasets and configuration on how to
 process the datasets. Users can currently pass both YAML or JSON based configuration files as data_configs.
 
```
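
Since the flag rename touches every example below, here is a minimal, hypothetical sketch of building such a config programmatically. The schema (`datasets`, `data_paths`, `data_handlers`) follows the predefined config added later in this commit; the file names and handler arguments are illustrative placeholders, not values from the repository.

```python
# Hypothetical sketch: assemble a minimal data_config and write it to YAML.
# "my_dataset.jsonl" and "my_data_config.yaml" are illustrative placeholders.
import yaml

data_config = {
    "datasets": [
        {
            "name": "dataset_1",
            "data_paths": ["my_dataset.jsonl"],
            "data_handlers": [
                {
                    "name": "tokenize_and_apply_chat_template_with_masking",
                    "arguments": {
                        "remove_columns": "all",
                        "fn_kwargs": {
                            "max_seq_length": 1024,
                            "conversation_column": "messages",
                        },
                    },
                }
            ],
        }
    ]
}

with open("my_data_config.yaml", "w", encoding="utf-8") as f:
    yaml.safe_dump(data_config, f, sort_keys=False)

# The trainer would then be pointed at the file via the renamed flag, e.g.:
#   python -m tuning.sft_trainer --data_config_path my_data_config.yaml ...
```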

docs/ept.md

Lines changed: 3 additions & 3 deletions
````diff
@@ -43,7 +43,7 @@ datasets:
 And the commandline passed to the library should include following.
 
 ```
---data_config <path to the data config> --packing=True --max_seq_len 8192
+--data_config_path <path to the data config> --packing=True --max_seq_len 8192
 ```
 
 Please note that for non tokenized dataset our code adds `EOS_TOKEN` to the lines, for e.g. `Tweet` column before passing that as a dataset.
@@ -102,7 +102,7 @@ NOTE: More in-depth documentation of `sampling_stopping_strategy` and how to spe
 Here also the command line arguments would be
 
 ```
---data_config <path to the data config> --packing=True --max_seq_len 8192
+--data_config_path <path to the data config> --packing=True --max_seq_len 8192
 ```
 
 The code again would add `EOS_TOKEN` to the non tokenized data before using it and also note that the `dataset_text_field` is assumed to be same across all datasets for now.
@@ -131,7 +131,7 @@ datasets:
 The command-line arguments passed to the library should include the following:
 
 ```
---data_config <path to the data config> --packing=True --max_seq_len 8192 --max_steps <num training steps>
+--data_config_path <path to the data config> --packing=True --max_seq_len 8192 --max_steps <num training steps>
 ```
 
 Please note when using streaming, user must pass `max_steps` instead of `num_train_epochs`. See advanced data preprocessing [document](./advanced-data-preprocessing.md#data-streaming) for more info.
````
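
The EOS-handling note in these docs can be made concrete with a short sketch. This is only an illustration of the documented behavior, not the library's implementation; the model name is an assumption and the `Tweet` column mirrors the example in the doc.

```python
# Illustrative sketch: append the tokenizer's EOS token to each row of a text
# column, mirroring what the docs say happens for non-tokenized datasets.
from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-base")
ds = Dataset.from_dict({"Tweet": ["my flight was delayed", "great service today"]})

ds = ds.map(lambda row: {"Tweet": row["Tweet"] + tokenizer.eos_token})
print(ds[0]["Tweet"])  # original text followed by the EOS token
```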

scripts/offline_data_processing.py

Lines changed: 5 additions & 37 deletions
```diff
@@ -5,20 +5,15 @@
 import traceback
 
 # Third Party
-from transformers import (
-    AutoTokenizer,
-    GPT2Tokenizer,
-    GPTNeoXTokenizerFast,
-    LlamaTokenizer,
-    LlamaTokenizerFast,
-)
+from transformers import AutoTokenizer
 
 # Local
 from tuning.config import configs
 from tuning.data.setup_dataprocessor import process_dataargs
 from tuning.sft_trainer import get_parser
 from tuning.utils.error_logging import USER_ERROR_EXIT_CODE, write_termination_log
 from tuning.utils.logging import set_log_level
+from tuning.utils.tokenizer_data_utils import get_special_tokens_dict
 
 
 def save_dataset_shards(
@@ -92,36 +87,9 @@ def get_processed_dataset(
         tokenizer.chat_template = data_args.chat_template
 
     # Prepare special tokens dictionary
-    special_tokens_dict = {}
-    if not model_args.tokenizer_name_or_path:
-        if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
-            special_tokens_dict["bos_token"] = "<s>"
-            special_tokens_dict["eos_token"] = "</s>"
-            special_tokens_dict["unk_token"] = "<unk>"
-            special_tokens_dict["pad_token"] = "<pad>"
-        elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
-            special_tokens_dict["pad_token"] = "<pad>"
-
-        if tokenizer.pad_token is None:
-            logger.warning(
-                "PAD token not found in tokenizer; setting PAD token to default."
-            )
-            special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
-        if tokenizer.eos_token is None:
-            logger.warning(
-                "EOS token not found in tokenizer; setting EOS token to default."
-            )
-            special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
-        if tokenizer.pad_token == tokenizer.eos_token:
-            logger.warning(
-                "PAD token and EOS token are the same. Overriding accordingly."
-            )
-            if tokenizer.eos_token != configs.DEFAULT_PAD_TOKEN:
-                tokenizer.pad_token = configs.DEFAULT_PAD_TOKEN
-                special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
-            else:
-                tokenizer.eos_token = configs.DEFAULT_EOS_TOKEN
-                special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
+    special_tokens_dict = get_special_tokens_dict(
+        tokenizer_name_or_path=model_args.tokenizer_name_or_path, tokenizer=tokenizer
+    )
 
     # adds user specified special tokens to vocab
     if data_args.add_special_tokens:
```
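
The deleted block above fully spells out the logic that moved into `get_special_tokens_dict`, so the extracted helper can be sketched with some confidence. Still, this is a reconstruction from the removed inline code, not a copy of `tuning/utils/tokenizer_data_utils.py`, and details such as the warning logs may differ.

```python
# Reconstructed sketch of the extracted helper, based on the inline logic
# deleted above; the real implementation lives in tuning/utils/tokenizer_data_utils.py.
from transformers import (
    GPT2Tokenizer,
    GPTNeoXTokenizerFast,
    LlamaTokenizer,
    LlamaTokenizerFast,
)

from tuning.config import configs


def get_special_tokens_dict(tokenizer_name_or_path, tokenizer):
    """Return the special tokens to add when no explicit tokenizer path is given."""
    special_tokens_dict = {}
    if not tokenizer_name_or_path:
        # Known tokenizer families get their conventional special tokens.
        if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
            special_tokens_dict["bos_token"] = "<s>"
            special_tokens_dict["eos_token"] = "</s>"
            special_tokens_dict["unk_token"] = "<unk>"
            special_tokens_dict["pad_token"] = "<pad>"
        elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
            special_tokens_dict["pad_token"] = "<pad>"

        # Fall back to the library defaults for missing PAD / EOS tokens.
        if tokenizer.pad_token is None:
            special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
        if tokenizer.eos_token is None:
            special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN

        # PAD and EOS must stay distinct so padding is never trained as EOS.
        if tokenizer.pad_token == tokenizer.eos_token:
            if tokenizer.eos_token != configs.DEFAULT_PAD_TOKEN:
                tokenizer.pad_token = configs.DEFAULT_PAD_TOKEN
                special_tokens_dict["pad_token"] = configs.DEFAULT_PAD_TOKEN
            else:
                tokenizer.eos_token = configs.DEFAULT_EOS_TOKEN
                special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN
    return special_tokens_dict
```

Extracting the block also lets the unit tests added later in this commit exercise the token logic directly, without going through the full offline processing script.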

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -40,6 +40,9 @@
 DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multi_turn_data_with_chat_template_granite_3_1B.yaml"
 )
+DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "mt_data_granite_3_1B_tokenize_and_mask_handler.yaml"
+)
 DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT = os.path.join(
     PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml"
 )
```
tests/artifacts/predefined_data_configs/mt_data_granite_3_1B_tokenize_and_mask_handler.yaml

Lines changed: 83 additions & 0 deletions (new file)

```yaml
dataprocessor:
  type: default
  chat_template: |
    {%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set loop_messages = messages[1:] %}
    {%- else %}
    {%- set system_message = "Knowledge Cutoff Date: April 2024.\nToday's Date: " + strftime_now('%B %d, %Y') + ".\nYou are Granite, developed by IBM." %}
    {%- if tools and documents %}
    {%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request.\n\nWrite the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
    {%- elif tools %}
    {%- set system_message = system_message + " You are a helpful AI assistant with access to the following tools. When a tool is required to answer the user's query, respond with <|tool_call|> followed by a JSON list of tools used. If a tool does not exist in the provided list of tools, notify the user that you do not have the ability to fulfill the request." %}
    {%- elif documents %}
    {%- set system_message = system_message + " Write the response to the user's input by strictly aligning with the facts in the provided documents. If the information needed to answer the question is not available in the documents, inform the user that the question cannot be answered based on the available data." %}
    {%- else %}
    {%- set system_message = system_message + " You are a helpful AI assistant." %}
    {%- endif %}
    {%- if 'citations' in controls and documents %}
    {%- set system_message = system_message + '\n\nIn your response, use the symbols <co> and </co> to indicate when a fact comes from a document in the search result, e.g <co>0</co> for a fact from document 0. Afterwards, list all the citations with their corresponding documents in an ordered list.' %}
    {%- endif %}
    {%- if 'hallucinations' in controls and documents %}
    {%- set system_message = system_message + '\n\nFinally, after the response is written, include a numbered list of sentences from the response that are potentially hallucinated and not based in the documents.' %}
    {%- endif %}
    {%- set loop_messages = messages %}
    {%- endif %}
    {{- '<|start_of_role|>system<|end_of_role|>' + system_message + '<|end_of_text|>\n' }}
    {%- if tools %}
    {{- '<|start_of_role|>tools<|end_of_role|>' }}
    {{- tools | tojson(indent=4) }}
    {{- '<|end_of_text|>\n' }}
    {%- endif %}
    {%- if documents %}
    {{- '<|start_of_role|>documents<|end_of_role|>' }}
    {%- for document in documents %}
    {{- 'Document ' + loop.index0 | string + '\n' }}
    {{- document['text'] }}
    {%- if not loop.last %}
    {{- '\n\n'}}
    {%- endif%}
    {%- endfor %}
    {{- '<|end_of_text|>\n' }}
    {%- endif %}
    {%- for message in loop_messages %}
    {{- '<|start_of_role|>' + message['role'] + '<|end_of_role|>' + message['content'] + '<|end_of_text|>\n' }}
    {%- if loop.last and add_generation_prompt %}
    {{- '<|start_of_role|>assistant' }}
    {%- if controls %}
    {{- ' ' + controls | tojson()}}
    {%- endif %}
    {{- '<|end_of_role|>' }}
    {%- endif %}
    {%- endfor %}
datasets:
  - name: dataset_1
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_chat_template_with_masking
        arguments:
          remove_columns: all
          fn_kwargs:
            max_seq_length: 1024
            conversation_column: "messages"
  - name: dataset_2
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_chat_template_with_masking
        arguments:
          remove_columns: all
          fn_kwargs:
            max_seq_length: 1024
            conversation_column: "messages"
  - name: dataset_3
    data_paths:
      - "FILE_PATH"
    data_handlers:
      - name: tokenize_and_apply_chat_template_with_masking
        arguments:
          remove_columns: all
          fn_kwargs:
            max_seq_length: 1024
            conversation_column: "messages"
```
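
A quick sanity check before handing such a config to the trainer is to load it and swap real paths in for the `FILE_PATH` placeholders. Below is a minimal sketch assuming only PyYAML; "my_chats.jsonl" is a stand-in, not a file in the repository.

```python
# Sketch: load the predefined config and replace the FILE_PATH placeholders.
import yaml

with open("mt_data_granite_3_1B_tokenize_and_mask_handler.yaml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

for dataset in config["datasets"]:
    dataset["data_paths"] = [
        "my_chats.jsonl" if path == "FILE_PATH" else path
        for path in dataset["data_paths"]
    ]

# Every dataset in this config uses the same handler and kwargs.
handler = config["datasets"][0]["data_handlers"][0]
assert handler["name"] == "tokenize_and_apply_chat_template_with_masking"
assert handler["arguments"]["fn_kwargs"]["max_seq_length"] == 1024
```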

tests/test_sft_trainer.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -39,6 +39,7 @@
 from tests.artifacts.predefined_data_configs import (
     DATA_CONFIG_DUPLICATE_COLUMNS,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER,
     DATA_CONFIG_MULTITURN_DATA_YAML,
     DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML,
     DATA_CONFIG_RENAME_RETAIN_COLUMNS,
@@ -1258,6 +1259,14 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
             ],
             DATA_CONFIG_MULTITURN_GRANITE_3_1B_DATA_YAML,
         ),
+        (
+            [
+                CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
+                CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
+                CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
+            ],
+            DATA_CONFIG_MULTITURN_CHAT_TOKENIZE_AND_MASKING_DATA_HANDLER,
+        ),
     ],
 )
 def test_run_chat_style_ft_using_dataconfig_for_chat_template(
@@ -1768,7 +1777,7 @@ def test_pretokenized_dataset_bad_args(dataset_text_field, response_template):
     data_args = copy.deepcopy(DATA_ARGS)
     data_args.dataset_text_field = dataset_text_field
     data_args.response_template = response_template
-    data_args.training_data_path = TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL
+    data_args.training_data_path = TWITTER_COMPLAINTS_TOKENIZED_JSON
     # We should raise an error since we should not have a dataset text
     # field or a response template if we have pretokenized data
     with pytest.raises(ValueError):
```
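
The comment in the last hunk states the invariant this test exercises: pretokenized data may carry neither a `dataset_text_field` nor a `response_template` (the fix points the test at an actually tokenized dataset). A hypothetical sketch of such a check follows; the real validation sits in the library's data-argument processing and is likely structured differently.

```python
# Hypothetical sketch of the invariant exercised by the fixed test above.
def validate_pretokenized_args(is_pretokenized, dataset_text_field, response_template):
    if is_pretokenized and (dataset_text_field or response_template):
        raise ValueError(
            "A dataset_text_field or response_template must not be provided "
            "together with a pretokenized dataset."
        )
```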
tests/utils/test_tokenizer_data_utils.py

Lines changed: 150 additions & 6 deletions

```diff
@@ -1,20 +1,164 @@
-# Third party
 # Third Party
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 # First Party
 from tests.artifacts.testdata import MODEL_NAME
 
 # Local
-# First party
-from tuning.utils.tokenizer_data_utils import tokenizer_and_embedding_resize
+from tuning.config import configs
+from tuning.utils.tokenizer_data_utils import (
+    get_special_tokens_dict,
+    tokenizer_and_embedding_resize,
+)
 
 
-def test_tokenizer_and_embedding_resize_return_values():
-    """Test to ensure number of added tokens are returned correctly"""
+def test_setting_special_tokens_with_LlamaTokenizerFast():
+    """
+    Unit test using a LlamaTokenizerFast tokenizer. This tokenizer is only missing a
+    PAD token; however, because it is a Llama tokenizer, the function automatically
+    adds the BOS, EOS, UNK and PAD tokens to the special tokens dict. The <pad> token
+    is then replaced with a <PAD> token, because the Llama tokenizer does not have a
+    pad token specified.
+    """
+    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0", legacy=True)
+    model_args = configs.ModelArguments()
+    special_tokens_dict = get_special_tokens_dict(
+        tokenizer_name_or_path=model_args.tokenizer_name_or_path, tokenizer=tokenizer
+    )
+    assert special_tokens_dict == {
+        "bos_token": "<s>",
+        "eos_token": "</s>",
+        "unk_token": "<unk>",
+        "pad_token": "<PAD>",
+    }
+
+
+def test_setting_special_tokens_with_GPT2TokenizerFast():
+    """
+    Unit test using a GPT2TokenizerFast tokenizer. This tokenizer is the case where
+    the EOS token equals the PAD token; both are <|endoftext|>. The pad token in the
+    tokenizer is therefore set to <PAD>, and "pad_token": "<PAD>" is also added to
+    the special tokens dict.
+    """
+    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-base")
+    model_args = configs.ModelArguments()
+    special_tokens_dict = get_special_tokens_dict(
+        tokenizer_name_or_path=model_args.tokenizer_name_or_path, tokenizer=tokenizer
+    )
+    assert special_tokens_dict == {
+        "pad_token": "<PAD>",
+    }
+
+
+def test_setting_special_tokens_with_GPTNeoXTokenizerFast():
+    """
+    Unit test using a GPTNeoXTokenizerFast tokenizer. This tokenizer is another one
+    that is hardcoded into the function to automatically add just a pad token to the
+    special tokens dict. However, the tokenizer itself is also missing a pad token,
+    so the function then replaces the <pad> token with the default <PAD> token.
+    """
+    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
+    model_args = configs.ModelArguments()
+    special_tokens_dict = get_special_tokens_dict(
+        tokenizer_name_or_path=model_args.tokenizer_name_or_path, tokenizer=tokenizer
+    )
+    assert special_tokens_dict == {
+        "pad_token": "<PAD>",
+    }
+
+
+def test_setting_special_tokens_when_missing_all_special_tokens():
+    """
+    Unit test using the GPT2TokenizerFast tokenizer. All the special tokens have been
+    removed from the tokenizer, so we expect all of them to appear in the special
+    tokens dict.
+    """
+    tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.1-8b-base")
+
+    # Set all special tokens to None
+    tokenizer.bos_token = None
+    tokenizer.eos_token = None
+    tokenizer.unk_token = None
+    tokenizer.pad_token = None
+
+    model_args = configs.ModelArguments()
+    special_tokens_dict = get_special_tokens_dict(
+        tokenizer_name_or_path=model_args.tokenizer_name_or_path, tokenizer=tokenizer
+    )
+    assert special_tokens_dict == {
+        "pad_token": "<PAD>",
+        "eos_token": "</s>",
+        "bos_token": "<s>",
+        "unk_token": "<unk>",
+    }
+
+
+def test_setting_special_tokens_when_path_is_not_none():
+    """
+    A simple unit test that sets the `tokenizer_name_or_path` argument in
+    `model_args` to a non-None value. Since the argument is not None, almost
+    the entire `get_special_tokens_dict` function is skipped and the
+    special tokens dict is expected to be empty.
+    """
+    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0", legacy=True)
+    model_args = configs.ModelArguments(tokenizer_name_or_path="test_path")
+    special_tokens_dict = get_special_tokens_dict(
+        tokenizer_name_or_path=model_args.tokenizer_name_or_path, tokenizer=tokenizer
+    )
+    # Assert special_tokens_dict is empty
+    assert not special_tokens_dict
+
+
+def test_tokenizer_and_embedding_resize_return_values_missing_one_token():
+    """
+    Tests the resizing function when the special tokens dict contains a PAD token,
+    which means the tokenizer is missing one special token.
+
+    `multiple_of` is left at its default of 1.
+    """
     special_tokens_dict = {"pad_token": "<pad>"}
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
     model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
     metadata = tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model)
     assert metadata["num_new_tokens"] == 1
-    assert "new_embedding_size" in metadata
+    assert metadata["new_embedding_size"] == len(tokenizer)
+
+
+def test_tokenizer_and_embedding_resize_return_values_missing_four_tokens():
+    """
+    Tests the resizing when the special tokens dict contains a PAD, EOS, BOS and UNK
+    token, which means the tokenizer is missing four special tokens.
+
+    `multiple_of` is left at its default of 1.
+    """
+    special_tokens_dict = {
+        "pad_token": "<PAD>",
+        "eos_token": "</s>",
+        "bos_token": "<s>",
+        "unk_token": "<unk>",
+    }
+    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0", legacy=True)
+    model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0")
+    metadata = tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model)
+    assert metadata["num_new_tokens"] == 4
+    assert metadata["new_embedding_size"] == len(tokenizer)
+
+
+def test_tokenizer_and_embedding_resize_return_values_multiple_of_two():
+    """
+    Tests the resizing when the special tokens dict contains a PAD, EOS, BOS and UNK
+    token, which means the tokenizer is missing four special tokens.
+
+    `multiple_of` is set to 2; this adds one to the count of num_new_tokens and one
+    to the count of new_embedding_size.
+    """
+    special_tokens_dict = {
+        "pad_token": "<PAD>",
+        "eos_token": "</s>",
+        "bos_token": "<s>",
+        "unk_token": "<unk>",
+    }
+    tokenizer = AutoTokenizer.from_pretrained("Maykeye/TinyLLama-v0", legacy=True)
+    model = AutoModelForCausalLM.from_pretrained("Maykeye/TinyLLama-v0")
+    metadata = tokenizer_and_embedding_resize(
+        special_tokens_dict, tokenizer, model, multiple_of=2
+    )
+    assert metadata["num_new_tokens"] == 5
+    assert metadata["new_embedding_size"] == len(tokenizer) + 1
```
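
The last two assertions pin down the `multiple_of` arithmetic: the embedding table is rounded up to the next multiple, and any padding rows are counted in `num_new_tokens`. Below is a sketch of that arithmetic as the tests imply it; the concrete vocabulary size is hypothetical, chosen so the rounding kicks in, and this is not the library's implementation.

```python
# Sketch of the rounding arithmetic implied by the tests; not the library's code.
import math


def resize_counts(vocab_size_before, num_added, multiple_of=1):
    """Compute the counts tokenizer_and_embedding_resize appears to report."""
    resized = vocab_size_before + num_added
    target = math.ceil(resized / multiple_of) * multiple_of
    padding_rows = target - resized  # extra rows added purely for alignment
    return {
        "num_new_tokens": num_added + padding_rows,
        "new_embedding_size": target,
    }


# With an odd post-add vocabulary and multiple_of=2, one padding row appears:
# num_new_tokens == 5 and new_embedding_size == len(tokenizer) + 1, as asserted.
print(resize_counts(vocab_size_before=32001, num_added=4, multiple_of=2))
```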
