diff --git a/tests/unit/test_files_checks.py b/tests/unit/test_files_checks.py index 37c698d2..68af1f6c 100644 --- a/tests/unit/test_files_checks.py +++ b/tests/unit/test_files_checks.py @@ -303,3 +303,92 @@ def test_check_jsonl_empty_messages(tmp_path: Path): assert ( "Expected a non-empty list of messages. Found empty list" in report["message"] ) + + +def test_check_jsonl_valid_weights_all_messages(tmp_path: Path): + file = tmp_path / "valid_weights_all.jsonl" + content = [ + { + "messages": [ + {"role": "user", "content": "Hello", "weight": 1}, + {"role": "assistant", "content": "Hi there!", "weight": 0}, + {"role": "user", "content": "How are you?", "weight": 1}, + {"role": "assistant", "content": "I'm doing well!", "weight": 1}, + ] + }, + { + "messages": [ + {"role": "system", "content": "You are helpful", "weight": 0}, + {"role": "user", "content": "What's the weather?", "weight": 1}, + {"role": "assistant", "content": "It's sunny today!", "weight": 1}, + ] + }, + ] + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + assert report["is_check_passed"] + assert report["num_samples"] == len(content) + + +def test_check_jsonl_valid_weights_mixed_with_none(tmp_path: Path): + file = tmp_path / "valid_weights_mixed.jsonl" + content = [ + { + "messages": [ + {"role": "user", "content": "Hello", "weight": 1}, + {"role": "assistant", "content": "Hi there!", "weight": 0}, + {"role": "user", "content": "How are you?"}, + {"role": "assistant", "content": "I'm doing well!"}, + ] + }, + { + "messages": [ + {"role": "user", "content": "What's the weather?"}, + {"role": "assistant", "content": "It's sunny today!"}, + ] + }, + ] + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + assert report["is_check_passed"] + assert report["num_samples"] == len(content) + + +def test_check_jsonl_invalid_weight_float(tmp_path: Path): + file = tmp_path / "invalid_weight_float.jsonl" + content = [ + { + "messages": [ + {"role": "user", "content": "Hello", "weight": 1.0}, + {"role": "assistant", "content": "Hi there!", "weight": 0}, + ] + } + ] + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + assert not report["is_check_passed"] + assert "Weight must be an integer" in report["message"] + + +def test_check_jsonl_invalid_weight(tmp_path: Path): + file = tmp_path / "invalid_weight.jsonl" + content = [ + { + "messages": [ + {"role": "user", "content": "Hello", "weight": 2}, + {"role": "assistant", "content": "Hi there!", "weight": 0}, + ] + } + ] + with file.open("w") as f: + f.write("\n".join(json.dumps(item) for item in content)) + + report = check_file(file) + assert not report["is_check_passed"] + assert "Weight must be either 0 or 1" in report["message"]