39 | 39 | from tests.artifacts.predefined_data_configs import ( |
40 | 40 | DATA_CONFIG_DUPLICATE_COLUMNS, |
41 | 41 | DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, |
| 42 | + DATA_CONFIG_MULTITURN_DATA_YAML, |
42 | 43 | DATA_CONFIG_RENAME_RETAIN_COLUMNS, |
43 | 44 | DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, |
44 | 45 | DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT, |
@@ -1041,6 +1042,35 @@ def test_run_chat_style_ft(dataset_path): |
1041 | 1042 | assert 'Provide two rhyming words for the word "love"' in output_inference |
1042 | 1043 |
1043 | 1044 |
| 1045 | +def test_run_chat_style_add_special_tokens_ft(): |
| 1046 | + """Test an e2e training run where chat special tokens are added to the tokenizer via command-line data args.""" |
| 1047 | + with tempfile.TemporaryDirectory() as tempdir: |
| 1048 | + |
| 1049 | + # Sample Hugging Face dataset ID |
| 1050 | + data_args = configs.DataArguments( |
| 1051 | + training_data_path="lhoestq/demo1", |
| 1052 | + data_formatter_template="### Text:{{review}} \n\n### Stars: {{star}}", |
| 1053 | + response_template="\n### Stars:", |
| 1054 | + add_special_tokens=["<|assistant|>", "<|user|>"], |
| 1055 | + ) |
| 1056 | + |
| 1057 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 1058 | + train_args.output_dir = tempdir |
| 1059 | + |
| 1060 | + sft_trainer.train(MODEL_ARGS, data_args, train_args) |
| 1061 | + |
| 1062 | + # validate the configs |
| 1063 | + _validate_training(tempdir) |
| 1064 | + checkpoint_path = _get_checkpoint_path(tempdir) |
| 1065 | + |
| 1066 | + # Load the tokenizer |
| 1067 | + tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint_path) |
| 1068 | + |
| 1069 | + # Check that every special token passed is present in the tokenizer vocab |
| 1070 | + for tok in data_args.add_special_tokens: |
| 1071 | + assert tok in tokenizer.vocab |
| 1072 | + |
| 1073 | + |
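A natural follow-on assertion (not part of the test above) would be that the
model's input embeddings were resized to cover the enlarged vocabulary once the
special tokens were added. A minimal sketch, assuming the checkpoint loads with
the standard transformers API:

    # Hypothetical extra check, not in the diff: after adding special tokens,
    # the embedding table should be at least as large as the tokenizer vocab.
    model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint_path)
    assert model.get_input_embeddings().weight.shape[0] >= len(tokenizer)
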
1044 | 1074 | @pytest.mark.parametrize( |
1045 | 1075 | "datafiles, dataconfigfile", |
1046 | 1076 | [ |
@@ -1117,6 +1147,76 @@ def test_run_chat_style_ft_using_dataconfig(datafiles, dataconfigfile): |
1117 | 1147 | assert 'Provide two rhyming words for the word "love"' in output_inference |
1118 | 1148 |
1119 | 1149 |
| 1150 | +@pytest.mark.parametrize( |
| 1151 | + "datafiles, dataconfigfile", |
| 1152 | + [ |
| 1153 | + ( |
| 1154 | + [CHAT_DATA_SINGLE_TURN, CHAT_DATA_MULTI_TURN, CHAT_DATA_SINGLE_TURN], |
| 1155 | + DATA_CONFIG_MULTITURN_DATA_YAML, |
| 1156 | + ) |
| 1157 | + ], |
| 1158 | +) |
| 1159 | +def test_run_chat_style_ft_using_dataconfig_for_chat_template( |
| 1160 | + datafiles, dataconfigfile |
| 1161 | +): |
| 1162 | + """Check if we can perform an e2e run with chat template |
| 1163 | + and multi turn chat training using data config.""" |
| 1164 | + with tempfile.TemporaryDirectory() as tempdir: |
| 1165 | + |
| 1166 | + data_args = copy.deepcopy(DATA_ARGS) |
| 1167 | + data_args.response_template = "<|assistant|>" |
| 1168 | + data_args.instruction_template = "<|user|>" |
| 1169 | + data_args.dataset_text_field = "new_formatted_field" |
| 1170 | + |
| 1171 | + handler_kwargs = {"dataset_text_field": data_args.dataset_text_field} |
| 1172 | + kwargs = { |
| 1173 | + "fn_kwargs": handler_kwargs, |
| 1174 | + "batched": False, |
| 1175 | + "remove_columns": "all", |
| 1176 | + } |
| 1177 | + |
| 1178 | + handler_config = DataHandlerConfig( |
| 1179 | + name="apply_tokenizer_chat_template", arguments=kwargs |
| 1180 | + ) |
| 1181 | + |
| 1182 | + model_args = copy.deepcopy(MODEL_ARGS) |
| 1183 | + model_args.tokenizer_name_or_path = CUSTOM_TOKENIZER_TINYLLAMA |
| 1184 | + |
| 1185 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 1186 | + train_args.output_dir = tempdir |
| 1187 | + |
| 1188 | + with tempfile.NamedTemporaryFile( |
| 1189 | + "w", delete=False, suffix=".yaml" |
| 1190 | + ) as temp_yaml_file: |
| 1191 | + with open(dataconfigfile, "r", encoding="utf-8") as f: |
| 1192 | + data = yaml.safe_load(f) |
| 1193 | + datasets = data["datasets"] |
| 1194 | + for i, d in enumerate(datasets): |
| 1195 | + d["data_paths"] = [datafiles[i]] |
| 1196 | + # Attach the chat-template handler to each dataset |
| 1197 | + d["data_handlers"] = [asdict(handler_config)] |
| 1198 | + yaml.dump(data, temp_yaml_file) |
| 1199 | + data_args.data_config_path = temp_yaml_file.name |
| 1200 | + |
| 1201 | + sft_trainer.train(model_args, data_args, train_args) |
| 1202 | + |
| 1203 | + # validate the configs |
| 1204 | + _validate_training(tempdir) |
| 1205 | + checkpoint_path = _get_checkpoint_path(tempdir) |
| 1206 | + |
| 1207 | + # Load the model |
| 1208 | + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) |
| 1209 | + |
| 1210 | + # Run inference on the text |
| 1211 | + output_inference = loaded_model.run( |
| 1212 | + '<|user|>\nProvide two rhyming words for the word "love"\n' |
| 1213 | + '<nopace></s><|assistant|>', |
| 1214 | + max_new_tokens=50, |
| 1215 | + ) |
| 1216 | + assert len(output_inference) > 0 |
| 1217 | + assert 'Provide two rhyming words for the word "love"' in output_inference |
| 1218 | + |
| 1219 | + |
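For readers without the fixture handy: DATA_CONFIG_MULTITURN_DATA_YAML is not
shown in this diff, so its exact contents are elided. Based only on how the
test mutates the loaded YAML, the structure it assumes looks roughly like the
sketch below (key names inferred from the test; everything else is
hypothetical):

    # Hypothetical shape of the loaded data config; only "datasets",
    # "data_paths", and "data_handlers" are evidenced by the test itself.
    expected_shape = {
        "datasets": [
            {
                "data_paths": ["<one of the datafiles, filled in by the test>"],
                "data_handlers": [
                    {
                        "name": "apply_tokenizer_chat_template",
                        "arguments": {
                            "fn_kwargs": {"dataset_text_field": "new_formatted_field"},
                            "batched": False,
                            "remove_columns": "all",
                        },
                    }
                ],
            },
            # ...one entry per element of `datafiles`
        ]
    }
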
1120 | 1220 | @pytest.mark.parametrize( |
1121 | 1221 | "data_args", |
1122 | 1222 | [ |