
Commit 6757c0b

Merge remote-tracking branch 'upstream/main'
2 parents 6108364 + da8ae5d commit 6757c0b

9 files changed, +173 -0 lines changed


README.md

Lines changed: 13 additions & 0 deletions

@@ -187,6 +187,19 @@ Here are some scenarios addressed in the flow chart:
 3. There might be special tokens used in chat template which the tokenizer might be unaware of, for example `<|start_of_role|>` which can cause issues during tokenization as it might not be treated as a single token


+#### Add Special Tokens
+Working with multi-turn chat data might require the tokenizer to use a few new control tokens (e.g. `<|assistant|>`, `[SYS]`) as described in the guidelines above. These special tokens might not be present in the tokenizer's vocabulary if the user is using a base model.
+
+Users can pass the `--add_special_tokens` argument, which adds the required tokens to the tokenizer's vocabulary.
+For example, the special tokens used in `--instruction_template`/`--response_template` can be passed as follows:
+
+```
+python -m tuning.sft_trainer \
+...
+--add_special_tokens "<|start_of_role|>" "<|end_of_role|>" \
+--instruction_template "<|start_of_role|>user<|end_of_role|>" \
+--response_template "<|start_of_role|>assistant<|end_of_role|>"
+```

 ### 4. Pre tokenized datasets.
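
To make the README addition above concrete, here is a minimal sketch, not this repository's implementation, of what adding control tokens amounts to with the standard Hugging Face `transformers` API; the `gpt2` checkpoint is only a stand-in for whichever base model is being tuned, and in this repo the equivalent work is routed through `tokenizer_and_embedding_resize` (see the `tuning/sft_trainer.py` diff further below).

```python
# Minimal sketch (assumption: standard transformers API, placeholder checkpoint).
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # stand-in for the base model actually being fine-tuned
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

new_tokens = ["<|start_of_role|>", "<|end_of_role|>"]
# Register the control tokens as additional special tokens.
num_added = tokenizer.add_special_tokens({"additional_special_tokens": new_tokens})

# If any tokens were actually new, the embedding matrix must grow to match.
if num_added > 0:
    model.resize_token_embeddings(len(tokenizer))

# Each control token now maps to a single id instead of being split into pieces.
assert len(tokenizer.encode("<|start_of_role|>", add_special_tokens=False)) == 1
```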

docs/advanced-data-preprocessing.md

Lines changed: 3 additions & 0 deletions

@@ -47,6 +47,8 @@ definitions:
       type: string
     seed:
       type: integer
+    chat_template:
+      type: string
   required:
   - type
   title: Dataprocessor
@@ -118,6 +120,7 @@ Users can create a data config file in any of YAML or JSON format they choose (w
 - `streaming` (optional, bool): Stream datasets using [IterableDatasets](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.IterableDataset).
 - `sampling_stopping_strategy` (optional, str): Dataset interleave stopping strategy in case of choosing to mix multiple datasets by weight, supported values are [`all_exhausted` or `first_exhausted`](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.stopping_strategy), defaults to `all_exhausted`.
 - `sampling_seed` (optional, int): [Sampling seed](https://huggingface.co/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.interleave_datasets.seed) to use for interleaving datasets, for reproducibility choose same value, defaults to 42.
+- `chat_template` (optional, str): Pass a `chat_template` via the data config for multi-turn data; it replaces the tokenizer's existing default chat template.

 `datasets` (list):
 - `name` (optional, str): A unique identifier for the dataset.
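
As a rough illustration of the `chat_template` option documented above, the following Python sketch writes out a data config of the documented shape; the dataset name, data path, and template string are placeholder assumptions, not values taken from this commit.

```python
# Illustrative only: build a data config whose dataprocessor block carries a
# chat_template, then dump it to YAML. All names/paths below are placeholders.
import yaml

data_config = {
    "dataprocessor": {
        "type": "default",
        # Any Jinja chat template string; this one is just a placeholder.
        "chat_template": "{% for m in messages %}{{ m['role'] }}: {{ m['content'] }}\n{% endfor %}",
    },
    "datasets": [
        {
            "name": "chat_dataset",            # hypothetical dataset name
            "data_paths": ["chat_data.jsonl"],  # hypothetical path
        }
    ],
}

with open("my_data_config.yaml", "w", encoding="utf-8") as f:
    yaml.dump(data_config, f)
```

The predefined test config added in this commit (shown further below) follows the same structure.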

tests/artifacts/predefined_data_configs/__init__.py

Lines changed: 3 additions & 0 deletions

@@ -34,6 +34,9 @@
 DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML = os.path.join(
     PREDEFINED_DATA_CONFIGS, "multiple_datasets_with_sampling.yaml"
 )
+DATA_CONFIG_MULTITURN_DATA_YAML = os.path.join(
+    PREDEFINED_DATA_CONFIGS, "multi_turn_data_with_chat_template.yaml"
+)
 DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT = os.path.join(
     PREDEFINED_DATA_CONFIGS, "tokenize_and_apply_input_masking_streaming.yaml"
 )
tests/artifacts/predefined_data_configs/multi_turn_data_with_chat_template.yaml

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+dataprocessor:
+  type: default
+  chat_template: |
+    {% for message in messages['messages'] %}
+    {% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + eos_token }}
+    {% elif message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + eos_token }}
+    {% elif message['role'] == 'assistant' %}{{ '<|assistant|>\n' + message['content'] + eos_token }}
+    {% endif %}
+    {% if loop.last and add_generation_prompt %}{{ '<|assistant|>' }}
+    {% endif %}
+    {% endfor %}
+datasets:
+  - name: dataset_1
+    data_paths:
+      - "FILE_PATH"
+    data_handlers:
+      - name: apply_tokenizer_chat_template
+        arguments:
+          fn_kwargs:
+            dataset_text_field: formatted_chat_data
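
To show what this committed template is expected to produce, here is an approximate stand-alone render using plain `jinja2` (assumed to be installed); the sample record is made up, and the template is collapsed onto single lines, whereas in the library the template is set on the tokenizer and applied through the `apply_tokenizer_chat_template` handler.

```python
# Approximate rendering of the committed chat template with plain jinja2.
# This is only meant to show the expected shape of the formatted text.
from jinja2 import Template

chat_template = (
    "{% for message in messages['messages'] %}"
    "{% if message['role'] == 'user' %}{{ '<|user|>\\n' + message['content'] + eos_token }}"
    "{% elif message['role'] == 'system' %}{{ '<|system|>\\n' + message['content'] + eos_token }}"
    "{% elif message['role'] == 'assistant' %}{{ '<|assistant|>\\n' + message['content'] + eos_token }}"
    "{% endif %}"
    "{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' }}{% endif %}"
    "{% endfor %}"
)

sample = {  # made-up record in the expected "messages" format
    "messages": [
        {"role": "user", "content": 'Provide two rhyming words for the word "love"'},
        {"role": "assistant", "content": "dove, glove"},
    ]
}

print(
    Template(chat_template).render(
        messages=sample, eos_token="</s>", add_generation_prompt=False
    )
)
# <|user|>
# Provide two rhyming words for the word "love"</s><|assistant|>
# dove, glove</s>
```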

tests/test_sft_trainer.py

Lines changed: 100 additions & 0 deletions

@@ -39,6 +39,7 @@
 from tests.artifacts.predefined_data_configs import (
     DATA_CONFIG_DUPLICATE_COLUMNS,
     DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML,
+    DATA_CONFIG_MULTITURN_DATA_YAML,
     DATA_CONFIG_RENAME_RETAIN_COLUMNS,
     DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML,
     DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT,
@@ -1041,6 +1042,35 @@ def test_run_chat_style_ft(dataset_path):
     assert 'Provide two rhyming words for the word "love"' in output_inference


+def test_run_chat_style_add_special_tokens_ft():
+    """Test to check an e2e multi turn chat training by adding special tokens via command line."""
+    with tempfile.TemporaryDirectory() as tempdir:
+
+        # sample hugging face dataset id
+        data_args = configs.DataArguments(
+            training_data_path="lhoestq/demo1",
+            data_formatter_template="### Text:{{review}} \n\n### Stars: {{star}}",
+            response_template="\n### Stars:",
+            add_special_tokens=["<|assistant|>", "<|user|>"],
+        )
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(MODEL_ARGS, data_args, train_args)
+
+        # validate the configs
+        _validate_training(tempdir)
+        checkpoint_path = _get_checkpoint_path(tempdir)
+
+        # Load the tokenizer
+        tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_path)
+
+        # Check if all special tokens passed are in tokenizer
+        for tok in data_args.add_special_tokens:
+            assert tok in tokenizer.vocab
+
+
 @pytest.mark.parametrize(
     "datafiles, dataconfigfile",
     [
@@ -1117,6 +1147,76 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile):
     assert 'Provide two rhyming words for the word "love"' in output_inference


+@pytest.mark.parametrize(
+    "datafiles, dataconfigfile",
+    [
+        (
+            [CHAT_DATA_SINGLE_TURN, CHAT_DATA_MULTI_TURN, CHAT_DATA_SINGLE_TURN],
+            DATA_CONFIG_MULTITURN_DATA_YAML,
+        )
+    ],
+)
+def test_run_chat_style_ft_using_dataconfig_for_chat_template(
+    datafiles, dataconfigfile
+):
+    """Check if we can perform an e2e run with chat template
+    and multi turn chat training using data config."""
+    with tempfile.TemporaryDirectory() as tempdir:
+
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.response_template = "<|assistant|>"
+        data_args.instruction_template = "<|user|>"
+        data_args.dataset_text_field = "new_formatted_field"
+
+        handler_kwargs = {"dataset_text_field": data_args.dataset_text_field}
+        kwargs = {
+            "fn_kwargs": handler_kwargs,
+            "batched": False,
+            "remove_columns": "all",
+        }
+
+        handler_config = DataHandlerConfig(
+            name="apply_tokenizer_chat_template", arguments=kwargs
+        )
+
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.tokenizer_name_or_path = CUSTOM_TOKENIZER_TINYLLAMA
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        with tempfile.NamedTemporaryFile(
+            "w", delete=False, suffix=".yaml"
+        ) as temp_yaml_file:
+            with open(dataconfigfile, "r", encoding="utf-8") as f:
+                data = yaml.safe_load(f)
+                datasets = data["datasets"]
+                for i, d in enumerate(datasets):
+                    d["data_paths"] = [datafiles[i]]
+                    # Basic chat datasets don't need data handling
+                    d["data_handlers"] = [asdict(handler_config)]
+                yaml.dump(data, temp_yaml_file)
+            data_args.data_config_path = temp_yaml_file.name
+
+        sft_trainer.train(model_args, data_args, train_args)
+
+        # validate the configs
+        _validate_training(tempdir)
+        checkpoint_path = _get_checkpoint_path(tempdir)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            '<|user|>\nProvide two rhyming words for the word "love"\n\
+<nopace></s><|assistant|>',
+            max_new_tokens=50,
+        )
+        assert len(output_inference) > 0
+        assert 'Provide two rhyming words for the word "love"' in output_inference
+
+
 @pytest.mark.parametrize(
     "data_args",
     [

tuning/config/configs.py

Lines changed: 8 additions & 0 deletions

@@ -122,6 +122,14 @@ class DataArguments:
             Passed in conjunction with response_template"
         },
     )
+    add_special_tokens: List[str] = field(
+        default=None,
+        metadata={
+            "help": "List of special tokens to be added to the tokenizer's vocabulary. \
+                The tokens are added as new tokens, \
+                which increases the vocabulary and the model embedding size."
+        },
+    )


 @dataclass
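
As a hedged sketch of how this new field might be filled from the command line, assuming the remaining `DataArguments` fields keep their defaults and using a placeholder data path, `HfArgumentParser` turns the space-separated values after `--add_special_tokens` into the `List[str]` field:

```python
# Sketch only: parse the new flag into DataArguments via HfArgumentParser.
from transformers import HfArgumentParser

from tuning.config import configs

parser = HfArgumentParser(configs.DataArguments)
(data_args,) = parser.parse_args_into_dataclasses(
    args=[
        "--training_data_path", "chat_data.jsonl",       # placeholder path
        "--add_special_tokens", "<|user|>", "<|assistant|>",
    ]
)
print(data_args.add_special_tokens)  # ['<|user|>', '<|assistant|>']
```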

tuning/data/data_config.py

Lines changed: 5 additions & 0 deletions

@@ -48,6 +48,7 @@ class DataPreProcessorConfig:
     # Default seed is not none to ensure reproducability
     sampling_seed: Optional[float] = 42
     streaming: Optional[bool] = False
+    chat_template: Optional[str] = None


 @dataclass
@@ -147,6 +148,10 @@ def _validate_dataprocessor_config(dataprocessor_config) -> DataPreProcessorConfig:
         streaming = kwargs["streaming"]
         assert isinstance(streaming, bool), f"streaming: {streaming} should be a bool"
         c.streaming = streaming
+    if "chat_template" in kwargs:
+        chat_template = kwargs["chat_template"]
+        assert isinstance(chat_template, str), "chat_template should be a string"
+        c.chat_template = chat_template
     return c
tuning/data/setup_dataprocessor.py

Lines changed: 10 additions & 0 deletions

@@ -75,6 +75,16 @@ def _process_dataconfig_file(
         tokenizer=tokenizer,
         additional_data_handlers=additional_data_handlers,
     )
+
+    if processor.processor_config.chat_template is not None:
+        if tokenizer.chat_template:
+            logger.warning(
+                "replacing existing chat_template %s with data config's chat_template %s",
+                tokenizer.chat_template,
+                processor.processor_config.chat_template,
+            )
+        tokenizer.chat_template = processor.processor_config.chat_template
+
     if processor.processor_config.streaming:
         if train_args.max_steps < 1:
             logging.error(
tuning/sft_trainer.py

Lines changed: 11 additions & 0 deletions

@@ -250,6 +250,10 @@ def train(
         )

     if data_args.chat_template:
+        # TODO: passing "\n" through the CLI causes parsing issues,
+        # hence providing a temporary fix
+        data_args.chat_template = data_args.chat_template.replace(r"\n", "\n")
+
         logger.info("adding chat_template to the tokenizer")
         if tokenizer.chat_template:
             logger.warning(
@@ -297,6 +301,13 @@ def train(
         tokenizer.eos_token = configs.DEFAULT_EOS_TOKEN
         special_tokens_dict["eos_token"] = configs.DEFAULT_EOS_TOKEN

+    # adds user specified special tokens to vocab
+    if data_args.add_special_tokens:
+        logger.info(
+            "Adding user-defined special tokens: %s ", data_args.add_special_tokens
+        )
+        special_tokens_dict["additional_special_tokens"] = data_args.add_special_tokens
+
     # TODO: lower priority but understand if resizing impacts inference quality and why its needed.
     # It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
     added_tokens_dict = tokenizer_and_embedding_resize(
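
A tiny stand-alone illustration, using made-up values, of why the `replace(r"\n", "\n")` fix above is needed when a chat template is passed through the shell:

```python
# A template typed on the shell usually reaches argparse as the two characters
# backslash + "n", not as a real newline; the replace() in train() converts it.
cli_value = r"<|user|>\n{{ content }}"        # as received from the command line
fixed_value = cli_value.replace(r"\n", "\n")  # after the fix in train()

assert "\n" not in cli_value   # no real newline yet
assert "\n" in fixed_value     # real newline after the replacement
```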
