Commit 11cd801

Merge remote-tracking branch 'upstream/main'
2 parents 5246ed2 + 65d5e54
File tree: 8 files changed, +289 -2 lines changed

README.md

Lines changed: 77 additions & 0 deletions
@@ -184,6 +184,83 @@ For the [granite model above](https://huggingface.co/ibm-granite/granite-3.0-8b-
The code internally uses [`DataCollatorForCompletionOnlyLM`](https://github.com/huggingface/trl/blob/main/trl/trainer/utils.py#L93) to mask the text so that the model learns only on the `assistant` responses, for both single-turn and multi-turn chat.

#### Aligning dataset formats

In some cases the chat template might not be aligned with the data format of the dataset. For example, consider the following data sample, and suppose we want to use the list of contents under the `messages` key for our multi-turn training job:

```
{
  "messages": [
    {"content": "You are an AI...", "role": "system"},
    {"content": "Look up a word...", "role": "user"},
    {"content": "A word that rhymes is 'mist'", "role": "assistant"}
  ],
  "group": "lab_extension",
  "dataset": "base/full-extension",
  "metadata": "{\"num_turns\": 2}"
}
```
202+
Different chat templates expect different data formats, and a chat template will not always align with the format of your dataset.

Here is an example of a chat template that iterates over the nested data sample by addressing the `messages` key via `for message in messages['messages']`:

```
{% for message in messages['messages'] %}\
{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + eos_token }}\
{% elif message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + eos_token }}\
{% elif message['role'] == 'assistant' %}{{ '<|assistant|>\n' + message['content'] + eos_token }}\
{% endif %}\
{% if loop.last and add_generation_prompt %}{{ '<|assistant|>' }}\
{% endif %}\
{% endfor %}
```
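To make the difference concrete, here is a minimal sketch (our illustration, not part of the library) that renders a nested sample with plain `jinja2`. Since no tokenizer is involved, `eos_token` is supplied by hand, and the generation-prompt branch is omitted for brevity:

```
from jinja2 import Template

# Template that unwraps the nested "messages" key, as in the example above
# (generation-prompt branch omitted for brevity).
NESTED_TEMPLATE = (
    "{% for message in messages['messages'] %}"
    "{% if message['role'] == 'user' %}{{ '<|user|>\n' + message['content'] + eos_token }}"
    "{% elif message['role'] == 'system' %}{{ '<|system|>\n' + message['content'] + eos_token }}"
    "{% elif message['role'] == 'assistant' %}{{ '<|assistant|>\n' + message['content'] + eos_token }}"
    "{% endif %}"
    "{% endfor %}"
)

sample = {
    "messages": [
        {"content": "You are an AI...", "role": "system"},
        {"content": "Look up a word...", "role": "user"},
        {"content": "A word that rhymes is 'mist'", "role": "assistant"},
    ]
}

# The whole sample dict is passed as `messages`; the template unwraps it itself.
print(Template(NESTED_TEMPLATE).render(messages=sample, eos_token="</s>"))
```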
215+
While the above template might suit certain data formats, not all chat templates access nested contents of a data sample.

In the following example, notice the `for message in messages` line, which does not access any nested contents and instead expects the list of messages to be passed directly to the chat template:

```
{%- for message in messages %}\
{%- if message['role'] == 'system' %}\
{{- '<|system|>\n' + message['content'] + '\n' }}\
{%- elif message['role'] == 'user' %}\
{{- '<|user|>\n' + message['content'] + '\n' }}\
{%- elif message['role'] == 'assistant' %}\
{%- if not loop.last %}\
{{- '<|assistant|>\n' + message['content'] + eos_token + '\n' }}\
{%- else %}\
{{- '<|assistant|>\n' + message['content'] + eos_token }}\
{%- endif %}\
{%- endif %}\
{%- if loop.last and add_generation_prompt %}\
{{- '<|assistant|>\n' }}\
{%- endif %}\
{%- endfor %}
```
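Before such a flat template can be applied, the nested list has to be pulled out of the sample first. A minimal sketch (again our illustration with plain `jinja2`, using a simplified role-generic template rather than the exact one above):

```
from jinja2 import Template

# Simplified flat template: iterates directly over `messages`,
# with no nested key access.
FLAT_TEMPLATE = (
    "{%- for message in messages %}"
    "{{- '<|' + message['role'] + '|>\n' + message['content'] + '\n' }}"
    "{%- endfor %}"
)

sample = {
    "conversations": [
        {"content": "You are an AI...", "role": "system"},
        {"content": "Look up a word...", "role": "user"},
        {"content": "A word that rhymes is 'mist'", "role": "assistant"},
    ]
}

# The nested list must be extracted before rendering; this extraction is
# what --dataset_conversation_field automates during training.
print(Template(FLAT_TEMPLATE).render(messages=sample["conversations"]))
```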
237+
When working with multi-turn datasets, it's often necessary to extract specific fields from the data depending on the format. For example, in many multi-turn datasets, conversations may be stored under a dedicated key (e.g., `conversations`, `messages`, etc.), and you may only need the content of that key for processing.

```
{
  "conversations": [
    {"content": "You are an AI...", "role": "system"},
    {"content": "Look up a word...", "role": "user"},
    {"content": "A word that rhymes is 'mist'", "role": "assistant"}
  ],
  "group": "lab_extension",
  "dataset": "base/full-extension",
  "metadata": "{\"num_turns\": 2}"
}
```
253+
To extract and use the `conversations` field, pass the following flag when running:

```
--dataset_conversation_field "conversations"
```
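For context, here is a sketch of how this flag might sit alongside the other chat-tuning options on the command line. The model and paths are placeholders, the surrounding flags mirror the dataclass fields exercised in the tests below, and you should adjust the entry point to however you normally launch training:

```
python -m tuning.sft_trainer \
  --model_name_or_path <base-model> \
  --training_data_path multi_turn_chat_conversations.jsonl \
  --response_template "<|assistant|>" \
  --instruction_template "<|user|>" \
  --dataset_conversation_field "conversations" \
  --output_dir <output-dir>
```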
257+
*Note:* In most cases, users of the `Granite 3.1+ Instruct` series models, which already ship with a chat template, should pass `--dataset_conversation_field "messages"` when using multi-turn data on the command line, or use the `conversations_column` argument in the [data handler](https://github.com/foundation-model-stack/fms-hf-tuning/blob/30ceecc63f3e2bf3aadba2dfc3336b62187c240f/tests/artifacts/predefined_data_configs/mt_data_granite_3_1B_tokenize_and_mask_handler.yaml#L63) that applies the chat template.

We recommend inspecting the data and chat template to decide whether you need to pass this flag.

### Guidelines

Depending on the scenario, users might need to decide how to use a chat template with their data, or which chat template suits their use case.

The following flow chart summarizes our guidelines:

tests/artifacts/testdata/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -68,6 +68,9 @@
 )
 CHAT_DATA_SINGLE_TURN = os.path.join(JSONL_DATA_DIR, "single_turn_chat.jsonl")
 CHAT_DATA_MULTI_TURN = os.path.join(JSONL_DATA_DIR, "multi_turn_chat.jsonl")
+CHAT_DATA_MULTI_TURN_CONVERSATIONS = os.path.join(
+    JSONL_DATA_DIR, "multi_turn_chat_conversations.jsonl"
+)
 CHAT_DATA_MULTI_TURN_GRANITE_3_1B = os.path.join(
     JSONL_DATA_DIR, "multi_turn_chat_granite_instruct.jsonl"
 )
multi_turn_chat_conversations.jsonl

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Using the word \"grace\", come up with a word that rhymes and has the same number of syllables\n<nopace>", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"grace\" and has the same number of syllables:\n1\\. Space", "role": "assistant"}, {"content": "Using the word \"grace\", come up with a word that rhymes and has the same number of syllables\n<nopace>", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"grace\" and has the same number of syllables:\n1\\. Space", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Tell me one word that rhymes and has the opposite meaning of \"open\".", "role": "user"}, {"content": "Sure! Here's a word that rhymes with \"open\" and has the opposite meaning:\n1\\. Shut", "role": "assistant"}, {"content": "Tell me one word that rhymes and has the opposite meaning of \"open\".", "role": "user"}, {"content": "Sure! Here's a word that rhymes with \"open\" and has the opposite meaning:\n1\\. Shut", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "For the word \"dream\", give an example of a word that rhymes with it and its synonym.", "role": "user"}, {"content": "Here's an example for \"dream\" that includes a word that rhymes with it and a synonym:\n1\\. Word that rhymes with \"dream\": \"beam\"\nSynonym: \"ideal\"", "role": "assistant"}, {"content": "For the word \"dream\", give an example of a word that rhymes with it and its synonym.", "role": "user"}, {"content": "Here's an example for \"dream\" that includes a word that rhymes with it and a synonym:\n1\\. Word that rhymes with \"dream\": \"beam\"\nSynonym: \"ideal\"", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Can you think of a word that rhymes with tree?", "role": "user"}, {"content": "Of course, I can help! Here's one word that rhymes with \"tree\":\n1\\. Three", "role": "assistant"}, {"content": "Can you think of a word that rhymes with tree?", "role": "user"}, {"content": "Of course, I can help! Here's one word that rhymes with \"tree\":\n1\\. Three", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Look up a word that rhymes with exist", "role": "user"}, {"content": "I found a word that rhymes with \"exist\":\n1\\. Mist", "role": "assistant"}, {"content": "Look up a word that rhymes with exist", "role": "user"}, {"content": "I found a word that rhymes with \"exist\":\n1\\. Mist", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Provide five words that rhyme with \"green\".", "role": "user"}, {"content": "Here are five words that rhyme with \"green\":\n1\\. Bean\n2\\. Clean\n3\\. Keen\n4\\. Mean\n5\\. Seen", "role": "assistant"}, {"content": "Provide five words that rhyme with \"green\".", "role": "user"}, {"content": "Here are five words that rhyme with \"green\":\n1\\. Bean\n2\\. Clean\n3\\. Keen\n4\\. Mean\n5\\. Seen", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "What are three words that rhyme with \"peace\"?", "role": "user"}, {"content": "Here are three words that rhyme with \"peace\":\n1\\. Breeze\n2\\. Pieces\n3\\. Increase", "role": "assistant"}, {"content": "What are three words that rhyme with \"peace\"?", "role": "user"}, {"content": "Here are three words that rhyme with \"peace\":\n1\\. Breeze\n2\\. Pieces\n3\\. Increase", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Can you name a word that rhymes with high and starts with the letter \"t\"?", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"high\" and starts with \"t\":\n1\\. Sky", "role": "assistant"}, {"content": "Can you name a word that rhymes with high and starts with the letter \"t\"?", "role": "user"}, {"content": "Certainly! Here's a word that rhymes with \"high\" and starts with \"t\":\n1\\. Sky", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Identify the word that rhymes with \"love\" and has the same vowel sounds in the same order as the word \"busy\".", "role": "user"}, {"content": "The word that rhymes with \"love\" and has the same vowel sounds in the same order as the word \"busy\" is:\n1\\. \"Shoe\".", "role": "assistant"}, {"content": "Identify the word that rhymes with \"love\" and has the same vowel sounds in the same order as the word \"busy\".", "role": "user"}, {"content": "The word that rhymes with \"love\" and has the same vowel sounds in the same order as the word \"busy\" is:\n1\\. \"Shoe\".", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}
{"conversations": [{"content": "You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.", "role": "system"}, {"content": "Provide two rhyming words for the word \"love\".\n<no above>", "role": "user"}, {"content": "Here are two rhyming words for \"love\":\n1\\. Dove\n2\\. Glove", "role": "assistant"}, {"content": "Provide two rhyming words for the word \"love\".\n<no above>", "role": "user"}, {"content": "Here are two rhyming words for \"love\":\n1\\. Dove\n2\\. Glove", "role": "assistant"}], "group": "lab_extension", "dataset": "base/full-extension", "metadata": "{\"num_turns\": 2}"}

tests/test_sft_trainer.py

Lines changed: 60 additions & 1 deletion
@@ -51,6 +51,7 @@
 )
 from tests.artifacts.testdata import (
     CHAT_DATA_MULTI_TURN,
+    CHAT_DATA_MULTI_TURN_CONVERSATIONS,
     CHAT_DATA_MULTI_TURN_GRANITE_3_1B,
     CHAT_DATA_SINGLE_TURN,
     CUSTOM_TOKENIZER_TINYLLAMA,
@@ -124,13 +125,18 @@
 PEFT_LORA_ARGS = peft_config.LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05)


-def test_resume_training_from_checkpoint():
+@pytest.mark.parametrize(
+    "enable_reduce_loss_sum",
+    [False, True],
+)
+def test_resume_training_from_checkpoint(enable_reduce_loss_sum):
     """
     Test that tuning resumes from the latest checkpoint, creating new checkpoints,
     and that checkpoints created before resuming tuning are not affected.
     """
     with tempfile.TemporaryDirectory() as tempdir:
         train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.enable_reduce_loss_sum = enable_reduce_loss_sum
         train_args.output_dir = tempdir

         sft_trainer.train(MODEL_ARGS, DATA_ARGS, train_args, None)
@@ -1139,6 +1145,59 @@ def test_run_chat_style_ft(dataset_path):
     assert 'Provide two rhyming words for the word "love"' in output_inference


+def test_run_chat_style_ft_dataset_conversation_field():
+    """Check if we can perform an e2e run with a chat template and multi-turn chat training."""
+    with tempfile.TemporaryDirectory() as tempdir:
+
+        data_args = copy.deepcopy(DATA_ARGS)
+        data_args.training_data_path = CHAT_DATA_MULTI_TURN_CONVERSATIONS
+
+        # sampled chat template from the granite3.1 instruct model
+        data_args.chat_template = "{%- for message in messages %}\
+{%- if message['role'] == 'system' %}\
+{{- '<|system|>\n' + message['content'] + '\n' }}\
+{%- elif message['role'] == 'user' %}\
+{{- '<|user|>\n' + message['content'] + '\n' }}\
+{%- elif message['role'] == 'assistant' %}\
+{%- if not loop.last %}\
+{{- '<|assistant|>\n' + message['content'] + eos_token + '\n' }}\
+{%- else %}\
+{{- '<|assistant|>\n' + message['content'] + eos_token }}\
+{%- endif %}\
+{%- endif %}\
+{%- if loop.last and add_generation_prompt %}\
+{{- '<|assistant|>\n' }}\
+{%- endif %}\
+{%- endfor %}"
+        data_args.response_template = "<|assistant|>"
+        data_args.instruction_template = "<|user|>"
+        data_args.dataset_conversation_field = "conversations"
+
+        model_args = copy.deepcopy(MODEL_ARGS)
+        model_args.tokenizer_name_or_path = CUSTOM_TOKENIZER_TINYLLAMA
+
+        train_args = copy.deepcopy(TRAIN_ARGS)
+        train_args.output_dir = tempdir
+
+        sft_trainer.train(model_args, data_args, train_args)
+
+        # validate the configs
+        _validate_training(tempdir)
+        checkpoint_path = _get_checkpoint_path(tempdir)
+
+        # Load the model
+        loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME)
+
+        # Run inference on the text
+        output_inference = loaded_model.run(
+            '<|user|>\nProvide two rhyming words for the word "love"\n\
+<nopace></s><|assistant|>',
+            max_new_tokens=50,
+        )
+        assert len(output_inference) > 0
+        assert 'Provide two rhyming words for the word "love"' in output_inference
+
+
 def test_run_chat_style_add_special_tokens_ft():
     """Test to check an e2e multi turn chat training by adding special tokens via command line."""
     with tempfile.TemporaryDirectory() as tempdir:
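To run just the new test locally, a standard pytest node-id invocation should work (assuming the repo's dev dependencies are installed):

```
pytest tests/test_sft_trainer.py::test_run_chat_style_ft_dataset_conversation_field
```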

tuning/config/configs.py

Lines changed: 16 additions & 0 deletions
@@ -77,6 +77,13 @@ class DataArguments:
             or data_formatter_template need to be supplied."
         },
     )
+    dataset_conversation_field: str = field(
+        default=None,
+        metadata={
+            "help": "Training dataset text field containing multi-turn chat data. \
+                Used as the key that points to the multi-turn data field."
+        },
+    )
     validation_data_path: str = field(
         default=None,
         metadata={"help": "Path to the validation data in JSON/JSONL format."},
@@ -201,6 +208,15 @@ class TrainingArguments(transformers.TrainingArguments):
             Other possible values are 'debug', 'info', 'warning', 'error' and 'critical'"
         },
     )
+    enable_reduce_loss_sum: bool = field(
+        default=False,
+        metadata={
+            "help": "Pass `True` to enable use of sum loss reduction on the loss function. \
+                Please note this feature is experimental and not fully supported. \
+                One known limitation is PEFT PT, so it is disabled \
+                for all PEFT runs by the library internally."
+        },
+    )


 @dataclass
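For intuition about what this flag changes, here is a minimal sketch of mean versus sum loss reduction for token-level cross-entropy (our illustration only; the library's actual behavior lives in `SumLossSFTTrainer`):

```
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)          # 4 tokens, vocabulary of 10
labels = torch.tensor([1, 2, 3, 4])  # target token ids

# Default behavior: loss averaged over tokens.
mean_loss = F.cross_entropy(logits, labels, reduction="mean")

# Sum reduction: each token contributes equally regardless of how many
# tokens a batch holds, which changes gradient scaling across batches.
sum_loss = F.cross_entropy(logits, labels, reduction="sum")

assert torch.isclose(sum_loss, mean_loss * labels.numel())
```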

tuning/data/setup_dataprocessor.py

Lines changed: 2 additions & 0 deletions
@@ -190,6 +190,8 @@ def _get_chat_dataset_handlers(data_args, tokenizer_kwargs):
     fn_kwargs = {}
     fn_kwargs["dataset_text_field"] = data_args.dataset_text_field
     fn_kwargs["tokenizer_kwargs"] = tokenizer_kwargs
+    if data_args.dataset_conversation_field is not None:
+        fn_kwargs["conversation_column"] = data_args.dataset_conversation_field

     kwargs = {"fn_kwargs": fn_kwargs, "batched": False, "remove_columns": "all"}
tuning/sft_trainer.py

Lines changed: 15 additions & 1 deletion
@@ -27,6 +27,7 @@
 from peft.utils.other import fsdp_auto_wrap_policy
 from torch.cuda import OutOfMemoryError
 from transformers import AutoModelForCausalLM, AutoTokenizer, TrainerCallback
+from transformers.trainer import _is_peft_model
 from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import is_accelerate_available
 from trl import SFTConfig, SFTTrainer
@@ -53,6 +54,7 @@
 from tuning.data.setup_dataprocessor import process_dataargs
 from tuning.trackers.tracker_factory import FILE_LOGGING_TRACKER, get_tracker
 from tuning.trainercontroller import TrainerControllerCallback
+from tuning.trainers.sum_loss_sft_trainer import SumLossSFTTrainer
 from tuning.utils.config_utils import get_hf_peft_config, get_json_config
 from tuning.utils.data_type_utils import get_torch_dtype
 from tuning.utils.error_logging import (
@@ -339,7 +341,12 @@ def train(
     }
     training_args = SFTConfig(**transformer_kwargs, **additional_args)

-    trainer = SFTTrainer(
+    if train_args.enable_reduce_loss_sum:
+        TrainerClass = SumLossSFTTrainer
+    else:
+        TrainerClass = SFTTrainer
+
+    trainer = TrainerClass(
         model=model,
         tokenizer=tokenizer,
         train_dataset=formatted_train_dataset,
@@ -350,6 +357,13 @@ def train(
         peft_config=peft_config,
     )

+    # Note this check has to be done post trainer initialization
+    if train_args.enable_reduce_loss_sum and _is_peft_model(trainer.model):
+        raise ValueError(
+            "❌ Loaded model is a PEFT model, but combining PEFT and sum loss "
+            "reduction is not supported yet. Set --enable_reduce_loss_sum to "
+            "False and retry."
+        )
+
     # We track additional metrics and experiment metadata after trainer object creation
     # this ensures that the process is not repeated multiple times for FSDP runs.
     if trainer.is_world_process_zero():
