|
41 | 41 | DATA_CONFIG_MULTIPLE_DATASETS_SAMPLING_YAML, |
42 | 42 | DATA_CONFIG_MULTITURN_DATA_YAML, |
43 | 43 | DATA_CONFIG_RENAME_RETAIN_COLUMNS, |
| 44 | + DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER, |
44 | 45 | DATA_CONFIG_TOKENIZE_AND_APPLY_INPUT_MASKING_YAML, |
| 46 | + DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER, |
45 | 47 | DATA_CONFIG_YAML_STREAMING_INPUT_OUTPUT, |
46 | 48 | DATA_CONFIG_YAML_STREAMING_PRETOKENIZED, |
47 | 49 | ) |
|
78 | 80 | DataPreProcessorConfig, |
79 | 81 | DataSetConfig, |
80 | 82 | ) |
81 | | -from tuning.data.data_handlers import add_tokenizer_eos_token |
| 83 | +from tuning.data.data_handlers import ( |
| 84 | + DataHandler, |
| 85 | + DataHandlerType, |
| 86 | + add_tokenizer_eos_token, |
| 87 | +) |
82 | 88 |
|
83 | 89 | MODEL_ARGS = configs.ModelArguments( |
84 | 90 | model_name_or_path=MODEL_NAME, use_flash_attn=False, torch_dtype="float32" |
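
The expanded import reflects the handler API these tests now exercise: custom data handlers are registered as `DataHandler` objects rather than bare callables. A minimal sketch of that registration shape, assuming only the fields this diff itself uses (`op`, `handler_type`, `allows_batching`); the element field name and the comment on how the handler is applied are assumptions, not confirmed by the diff:

```python
from tuning.data.data_handlers import DataHandler, DataHandlerType

def append_eos(element, tokenizer, **kwargs):
    # Map-style handlers take one element and return the updated element.
    element["output"] = element["output"] + tokenizer.eos_token  # hypothetical field
    return element

# Wrap the callable so the framework knows how to apply it.
handler = DataHandler(
    op=append_eos,
    handler_type=DataHandlerType.MAP,  # assumed to map over the dataset
    allows_batching=False,             # one element at a time
)
```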
@@ -321,14 +327,6 @@ def _get_training_logs_by_epoch(dir_path: str, epoch: int = None): |
321 | 327 | return data_list |
322 | 328 |
|
323 | 329 |
|
324 | | -def test_run_train_requires_output_dir(): |
325 | | - """Check fails when output dir not provided.""" |
326 | | - updated_output_dir_train_args = copy.deepcopy(TRAIN_ARGS) |
327 | | - updated_output_dir_train_args.output_dir = None |
328 | | - with pytest.raises(TypeError): |
329 | | - sft_trainer.train(MODEL_ARGS, DATA_ARGS, updated_output_dir_train_args, None) |
330 | | - |
331 | | - |
332 | 330 | def test_run_train_fails_training_data_path_not_exist(): |
333 | 331 | """Check fails when data path not found.""" |
334 | 332 | updated_data_path_args = copy.deepcopy(DATA_ARGS) |
@@ -996,6 +994,97 @@ def test_run_training_with_pretokenised_dataset_containing_input_ids(): |
996 | 994 | assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference |
997 | 995 |
|
998 | 996 |
|
| 997 | +def test_run_training_with_data_tokenized_using_tokenizer_handler(): |
| 998 | +    """Ensure training on a non-tokenized dataset works by tokenizing it with
| 999 | +    the tokenizer data handler via the data config."""
| 1000 | + with tempfile.TemporaryDirectory() as tempdir: |
| 1001 | + |
| 1002 | + data_args = copy.deepcopy(DATA_ARGS) |
| 1003 | + |
| 1004 | +        # Set response_template and training_data_path to None
| 1005 | + data_args.response_template = None |
| 1006 | + data_args.training_data_path = None |
| 1007 | + |
| 1008 | + dataconfigfile = DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER |
| 1009 | + datapath = TWITTER_COMPLAINTS_DATA_JSONL |
| 1010 | + |
| 1011 | +        # Set data_paths in the data config file to the test dataset
| 1012 | + with tempfile.NamedTemporaryFile( |
| 1013 | + "w", delete=False, suffix=".yaml" |
| 1014 | + ) as temp_yaml_file: |
| 1015 | + with open(dataconfigfile, "r", encoding="utf-8") as f: |
| 1016 | + data = yaml.safe_load(f) |
| 1017 | + datasets = data["datasets"] |
| 1018 | +            for d in datasets:
| 1019 | + d["data_paths"] = [datapath] |
| 1020 | + yaml.dump(data, temp_yaml_file) |
| 1021 | + data_args.data_config_path = temp_yaml_file.name |
| 1022 | + |
| 1023 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 1024 | + train_args.output_dir = tempdir |
| 1025 | + |
| 1026 | + sft_trainer.train(MODEL_ARGS, data_args, train_args) |
| 1027 | + |
| 1028 | + # validate full ft configs |
| 1029 | + _validate_training(tempdir) |
| 1030 | + checkpoint_path = _get_checkpoint_path(tempdir) |
| 1031 | + |
| 1032 | + # Load the model |
| 1033 | + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) |
| 1034 | + |
| 1035 | + # Run inference on the text |
| 1036 | + output_inference = loaded_model.run( |
| 1037 | + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 |
| 1038 | + ) |
| 1039 | + assert len(output_inference) > 0 |
| 1040 | + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference |
| 1041 | + |
| 1042 | + |
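
For reference, the YAML patching above relies only on the config's top-level shape: a `datasets` list whose entries accept a `data_paths` key. Expressed as the Python dict `yaml.safe_load` would return (the dataset name and handler entries are hypothetical; the real ones live in the DATA_CONFIG_TOKENIZE_AND_TRAIN_WITH_HANDLER file, which is not shown in this diff):

```python
# Minimal config shape assumed by the patching loop in the test above.
data = {
    "datasets": [
        {
            "name": "dataset_1",                  # hypothetical
            "data_paths": ["<set by the test>"],  # overwritten at runtime
            # handler entries omitted; defined in the real config file
        }
    ],
}
```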
| 1043 | +def test_run_training_with_skip_large_text_handler(): |
| 1044 | +    """Ensure that training succeeds after applying the skip-large-text handler."""
| 1045 | + with tempfile.TemporaryDirectory() as tempdir: |
| 1046 | + |
| 1047 | + data_args = copy.deepcopy(DATA_ARGS) |
| 1048 | + |
| 1049 | +        # Set response_template and training_data_path to None
| 1050 | + data_args.response_template = None |
| 1051 | + data_args.training_data_path = None |
| 1052 | + |
| 1053 | + dataconfigfile = DATA_CONFIG_SKIP_LARGE_TEXT_HANDLER |
| 1054 | + datapath = TWITTER_COMPLAINTS_TOKENIZED_JSON |
| 1055 | + |
| 1056 | +        # Set data_paths in the data config file to the test dataset
| 1057 | + with tempfile.NamedTemporaryFile( |
| 1058 | + "w", delete=False, suffix=".yaml" |
| 1059 | + ) as temp_yaml_file: |
| 1060 | + with open(dataconfigfile, "r", encoding="utf-8") as f: |
| 1061 | + data = yaml.safe_load(f) |
| 1062 | + datasets = data["datasets"] |
| 1063 | +            for d in datasets:
| 1064 | + d["data_paths"] = [datapath] |
| 1065 | + yaml.dump(data, temp_yaml_file) |
| 1066 | + data_args.data_config_path = temp_yaml_file.name |
| 1067 | + |
| 1068 | + train_args = copy.deepcopy(TRAIN_ARGS) |
| 1069 | + train_args.output_dir = tempdir |
| 1070 | + |
| 1071 | + sft_trainer.train(MODEL_ARGS, data_args, train_args) |
| 1072 | + |
| 1073 | + # validate full ft configs |
| 1074 | + _validate_training(tempdir) |
| 1075 | + checkpoint_path = _get_checkpoint_path(tempdir) |
| 1076 | + |
| 1077 | + # Load the model |
| 1078 | + loaded_model = TunedCausalLM.load(checkpoint_path, MODEL_NAME) |
| 1079 | + |
| 1080 | + # Run inference on the text |
| 1081 | + output_inference = loaded_model.run( |
| 1082 | + "### Text: @NortonSupport Thanks much.\n\n### Label:", max_new_tokens=50 |
| 1083 | + ) |
| 1084 | + assert len(output_inference) > 0 |
| 1085 | + assert "### Text: @NortonSupport Thanks much.\n\n### Label:" in output_inference |
| 1086 | + |
| 1087 | + |
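
The skip-large-text handler is driven entirely by the config file here. Conceptually it is a filter-style handler that drops overlong examples before training; a hypothetical sketch of such a filter follows (the real handler's name, its signature, and the existence of a FILTER member on DataHandlerType are assumptions, not confirmed by this diff):

```python
def skip_large_text(element, column_name="input_ids", max_length=1024, **kwargs):
    # Filter-style handlers return True to keep an element, False to drop it.
    return len(element[column_name]) <= max_length

skip_handler = DataHandler(
    op=skip_large_text,
    handler_type=DataHandlerType.FILTER,  # assumed filter variant
    allows_batching=False,
)
```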
999 | 1088 | @pytest.mark.parametrize( |
1000 | 1089 | "dataset_path", |
1001 | 1090 | [CHAT_DATA_SINGLE_TURN, CHAT_DATA_MULTI_TURN], |
@@ -1656,7 +1745,8 @@ def test_run_with_bad_additional_data_handlers(additional_handlers): |
1656 | 1745 | train_args.output_dir = tempdir |
1657 | 1746 |
|
1658 | 1747 | with pytest.raises( |
1659 | | - ValueError, match="Handlers should be of type Dict, str to callable" |
| 1748 | + ValueError, |
| 1749 | + match="Handler should be of type tuning.data_handler.DataHandler, and name of str", |
1660 | 1750 | ): |
1661 | 1751 | sft_trainer.train( |
1662 | 1752 | MODEL_ARGS, |
@@ -1725,6 +1815,12 @@ def test_handler(element, tokenizer, **kwargs): |
1725 | 1815 | DATA_ARGS, |
1726 | 1816 | train_args, |
1727 | 1817 | PEFT_PT_ARGS, |
1728 | | - additional_data_handlers={TEST_HANDLER: test_handler}, |
| 1818 | + additional_data_handlers={ |
| 1819 | + TEST_HANDLER: DataHandler( |
| 1820 | + op=test_handler, |
| 1821 | + handler_type=DataHandlerType.MAP, |
| 1822 | + allows_batching=False, |
| 1823 | + ) |
| 1824 | + }, |
1729 | 1825 | ) |
1730 | 1826 | _validate_training(tempdir) |
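
This last hunk captures the design shift behind the whole diff: entries in `additional_data_handlers` are now full `DataHandler` objects rather than raw callables, presumably so the framework can tell from the registration itself whether a handler maps or filters and whether batched application is safe, which the updated error message two hunks above now enforces.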