-
Notifications
You must be signed in to change notification settings - Fork 21
Refactor and update file checks logic #373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 12 commits
2035ffb
6c50c1f
2c14af0
1a261d0
4a7797e
dfea707
f11307e
775167b
88c54fc
0aae8f2
4fd70eb
b08bdda
bbc32bc
2a1efd1
c79aca1
93e580a
92e4748
92be6d7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -102,81 +102,163 @@ def check_file( | |
| return report_dict | ||
|
|
||
|
|
||
| def validate_messages(messages: List[Dict[str, str | bool]], idx: int) -> None: | ||
| """Validate the messages column.""" | ||
| def _check_conversation_type(messages: List[Dict[str, str | bool]], idx: int) -> None: | ||
| """Check that the conversation has correct type. | ||
|
|
||
| Args: | ||
| messages: The messages in the conversation. | ||
| Can be any type, this function ensures that the messages are a list of dictionaries. | ||
| idx: Line number in the file. | ||
|
|
||
| Raises: | ||
| InvalidFileFormatError: If the conversation type is invalid. | ||
| """ | ||
| if not isinstance(messages, list): | ||
| raise InvalidFileFormatError( | ||
| message=f"Invalid format on line {idx + 1} of the input file. " | ||
| f"Expected a list of messages. Found {type(messages)}", | ||
| f"The `messages` column must be a list. Found {type(messages)}", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
| if not messages: | ||
| if len(messages) == 0: | ||
| raise InvalidFileFormatError( | ||
| message=f"Invalid format on line {idx + 1} of the input file. " | ||
| f"Expected a non-empty list of messages. Found empty list", | ||
| f"The `messages` column must not be empty.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| has_weights = any("weight" in message for message in messages) | ||
|
|
||
| previous_role = None | ||
| for message in messages: | ||
| if not isinstance(message, dict): | ||
| raise InvalidFileFormatError( | ||
| message=f"Invalid format on line {idx + 1} of the input file. " | ||
| f"Expected a dictionary in the messages list. Found {type(message)}", | ||
| f"The `messages` column must be a list of dicts. Found {type(message)}", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| for column in REQUIRED_COLUMNS_MESSAGE: | ||
| if column not in message: | ||
| raise InvalidFileFormatError( | ||
| message=f"Field `{column}` is missing for a turn `{message}` on line {idx + 1} " | ||
| "of the the input file.", | ||
| message=f"Missing required column `{column}` in message on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
| else: | ||
| if not isinstance(message[column], str): | ||
| raise InvalidFileFormatError( | ||
| message=f"Invalid format on line {idx + 1} in the column {column} for turn `{message}` " | ||
| f"of the input file. Expected string. Found {type(message[column])}", | ||
| line_number=idx + 1, | ||
| error_source="text_field", | ||
| ) | ||
|
|
||
| if has_weights and "weight" in message: | ||
| weight = message["weight"] | ||
| if not isinstance(weight, int): | ||
| raise InvalidFileFormatError( | ||
| message="Weight must be an integer", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
| if weight not in {0, 1}: | ||
| if not isinstance(message[column], str): | ||
| raise InvalidFileFormatError( | ||
| message="Weight must be either 0 or 1", | ||
| message=f"Column `{column}` is not a string on line {idx + 1}. Found {type(message[column])}", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| error_source="text_field", | ||
| ) | ||
| if message["role"] not in POSSIBLE_ROLES_CONVERSATION: | ||
|
|
||
|
|
||
| def _check_conversation_roles( | ||
| require_assistant_role: bool, assistant_role_exists: bool, idx: int | ||
| ) -> None: | ||
| """Check that the conversation has correct roles. | ||
|
|
||
| Args: | ||
| require_assistant_role: Whether to require at least one assistant role. | ||
| assistant_role_exists: Whether an assistant role exists in the conversation. | ||
| idx: Line number in the file. | ||
|
|
||
| Raises: | ||
| InvalidFileFormatError: If the conversation roles are invalid. | ||
| """ | ||
| if require_assistant_role and not assistant_role_exists: | ||
| raise InvalidFileFormatError( | ||
| message=f"Invalid format on line {idx + 1} of the input file. " | ||
| "At least one message with the assistant role must be present in the example.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
|
|
||
| def _check_message_weight(message: Dict[str, str | bool], idx: int) -> None: | ||
| """Check that the message has a weight with the correct type and value. | ||
|
|
||
| Args: | ||
| message: The message to check. | ||
| idx: Line number in the file. | ||
|
|
||
| Raises: | ||
| InvalidFileFormatError: If the message weight is invalid. | ||
| """ | ||
| if "weight" in message: | ||
| weight = message["weight"] | ||
| if not isinstance(weight, int): | ||
| raise InvalidFileFormatError( | ||
| message=f"Found invalid role `{message['role']}` in the messages on the line {idx + 1}. " | ||
| f"Possible roles in the conversation are: {POSSIBLE_ROLES_CONVERSATION}", | ||
| message=f"Weight must be an integer on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| if previous_role == message["role"]: | ||
| if weight not in {0, 1}: | ||
| raise InvalidFileFormatError( | ||
| message=f"Invalid role turns on line {idx + 1} of the input file. " | ||
| "`user` and `assistant` roles must alternate user/assistant/user/assistant/...", | ||
| message=f"Weight must be either 0 or 1 on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
| previous_role = message["role"] | ||
|
|
||
|
|
||
| def _check_message_role( | ||
| message: Dict[str, str | bool], previous_role: str | None, idx: int | ||
| ) -> str | bool: | ||
| """Check that the message has correct roles. | ||
|
|
||
| Args: | ||
| message: The message to check. | ||
| previous_role: The role of the previous message. | ||
| idx: Line number in the file. | ||
|
|
||
| Returns: | ||
| str: The role of the current message. | ||
|
|
||
| Raises: | ||
| InvalidFileFormatError: If the message role is invalid. | ||
| """ | ||
| if message["role"] not in POSSIBLE_ROLES_CONVERSATION: | ||
| raise InvalidFileFormatError( | ||
| message=f"Invalid role `{message['role']}` in conversation on line {idx + 1}. " | ||
| f"Possible roles: {', '.join(POSSIBLE_ROLES_CONVERSATION)}", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
| if previous_role is not None and message["role"] == previous_role: | ||
| raise InvalidFileFormatError( | ||
| message=f"Invalid role turns on line {idx + 1} of the input file. " | ||
| "After the optional system message, conversation roles must alternate between user/assistant/user/assistant.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
| return message["role"] | ||
|
|
||
|
|
||
| def validate_messages( | ||
| messages: List[Dict[str, str | bool]], idx: int, require_assistant_role: bool = True | ||
| ) -> None: | ||
| """Validate the messages column. | ||
|
|
||
| Args: | ||
| messages: List of message dictionaries to validate. | ||
| idx: Line number in the file. | ||
| require_assistant_role: Whether to require at least one assistant role. | ||
|
|
||
| Raises: | ||
| InvalidFileFormatError: If the messages are invalid. | ||
| """ | ||
| _check_conversation_type(messages, idx) | ||
|
|
||
| has_weights = any("weight" in message for message in messages) | ||
| previous_role = None | ||
| assistant_role_exists = False | ||
|
|
||
| for message in messages: | ||
| if has_weights: | ||
| _check_message_weight(message, idx) | ||
| previous_role = _check_message_role(message, previous_role, idx) | ||
| assistant_role_exists |= previous_role == "assistant" | ||
|
|
||
| _check_conversation_roles(require_assistant_role, assistant_role_exists, idx) | ||
|
|
||
|
|
||
| def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None: | ||
|
|
@@ -203,37 +285,73 @@ def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> None: | |
| error_source="key_value", | ||
| ) | ||
|
|
||
| validate_messages(example["input"]["messages"], idx) | ||
| validate_messages(example["input"]["messages"], idx, require_assistant_role=False) | ||
|
|
||
| if example["input"]["messages"][-1]["role"] == "assistant": | ||
| raise InvalidFileFormatError( | ||
| message=f"The last message in the input conversation must not be from the assistant on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| keys = ["preferred_output", "non_preferred_output"] | ||
|
|
||
| for key in keys: | ||
| if key not in example: | ||
| raise InvalidFileFormatError( | ||
| message=f"The dataset is malformed, the `{key}` field must be present in the input dictionary on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| if not isinstance(example[key], list): | ||
| raise InvalidFileFormatError( | ||
| message=f"The dataset is malformed, the `{key}` field must be a list on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| if len(example[key]) != 1: | ||
| raise InvalidFileFormatError( | ||
| message=f"The dataset is malformed, the `{key}` list must contain exactly one message on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| for output_field in ["preferred_output", "non_preferred_output"]: | ||
| if not isinstance(example[output_field], list): | ||
| if not isinstance(example[key][0], dict): | ||
| raise InvalidFileFormatError( | ||
| message=f"The dataset is malformed, the `{output_field}` field must be a list.", | ||
| message=f"The dataset is malformed, the first element of `{key}` must be a dictionary on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| if len(example[output_field]) != 1: | ||
| if "role" not in example[key][0]: | ||
| raise InvalidFileFormatError( | ||
| message=f"The dataset is malformed, the `{output_field}` list must contain exactly one message.", | ||
| message=f"The dataset is malformed, the first element of `{key}` must have a 'role' field on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
| if "role" not in example[output_field][0]: | ||
|
|
||
| if example[key][0]["role"] != "assistant": | ||
| raise InvalidFileFormatError( | ||
| message=f"The dataset is malformed, the `{output_field}` message is missing the `role` field.", | ||
| message=f"The dataset is malformed, the first element of `{key}` must have the 'assistant' role on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
| elif example[output_field][0]["role"] != "assistant": | ||
|
|
||
| if "content" not in example[key][0]: | ||
| raise InvalidFileFormatError( | ||
| message=f"The dataset is malformed, the `{output_field}` must contain an assistant message.", | ||
| message=f"The dataset is malformed, the first element of `{key}` must have a 'content' field on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
|
||
| validate_messages(example["preferred_output"], idx) | ||
| validate_messages(example["non_preferred_output"], idx) | ||
| if not isinstance(example[key][0]["content"], str): | ||
| raise InvalidFileFormatError( | ||
| message=f"The dataset is malformed, the 'content' field in `{key}` must be a string on line {idx + 1}.", | ||
| line_number=idx + 1, | ||
| error_source="key_value", | ||
| ) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
|
|
||
|
|
||
| def _check_utf8(file: Path) -> Dict[str, Any]: | ||
|
|
@@ -410,7 +528,12 @@ def _check_jsonl(file: Path, purpose: FilePurpose | str) -> Dict[str, Any]: | |
| message_column = JSONL_REQUIRED_COLUMNS_MAP[ | ||
| DatasetFormat.CONVERSATION | ||
| ][0] | ||
| validate_messages(json_line[message_column], idx) | ||
| require_assistant = purpose != FilePurpose.Eval | ||
| validate_messages( | ||
| json_line[message_column], | ||
| idx, | ||
| require_assistant_role=require_assistant, | ||
| ) | ||
| else: | ||
| for column in JSONL_REQUIRED_COLUMNS_MAP[current_format]: | ||
| if not isinstance(json_line[column], str): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -182,7 +182,12 @@ def test_check_jsonl_inconsistent_dataset_format(tmp_path: Path): | |
| # Create a JSONL file with inconsistent dataset formats | ||
| file = tmp_path / "inconsistent_format.jsonl" | ||
| content = [ | ||
| {"messages": [{"role": "user", "content": "Hi"}]}, | ||
| { | ||
| "messages": [ | ||
| {"role": "user", "content": "Hi"}, | ||
| {"role": "assistant", "content": "Hi! How can I help you?"}, | ||
| ] | ||
| }, | ||
| {"text": "How are you?"}, # Missing 'messages' | ||
| ] | ||
| with file.open("w") as f: | ||
|
|
@@ -207,7 +212,7 @@ def test_check_jsonl_invalid_role(tmp_path: Path): | |
| report = check_file(file) | ||
|
|
||
| assert not report["is_check_passed"] | ||
| assert "Found invalid role `invalid_role`" in report["message"] | ||
| assert "Invalid role `invalid_role` in conversation" in report["message"] | ||
cursor[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| def test_check_jsonl_non_alternating_roles(tmp_path: Path): | ||
|
|
@@ -230,6 +235,22 @@ def test_check_jsonl_non_alternating_roles(tmp_path: Path): | |
| assert "Invalid role turns" in report["message"] | ||
|
|
||
|
|
||
| def test_check_jsonl_non_alternating_roles(tmp_path: Path): | ||
| # Create a JSONL file with non-alternating user/assistant roles | ||
| file = tmp_path / "non_alternating_roles.jsonl" | ||
| content = [{"messages": [{"role": "user", "content": "Hi"}]}] | ||
| with file.open("w") as f: | ||
| f.write("\n".join(json.dumps(item) for item in content)) | ||
|
|
||
| report = check_file(file) | ||
|
|
||
| assert not report["is_check_passed"] | ||
| assert ( | ||
| "At least one message with the assistant role must be present" | ||
| in report["message"] | ||
| ) | ||
|
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bug: Duplicate Test Overwrites Original Case. The … |
||
|
|
||
| def test_check_jsonl_invalid_value_type(tmp_path: Path): | ||
| # Create a JSONL file with an invalid value type | ||
| file = tmp_path / "invalid_value_type.jsonl" | ||
|
|
@@ -257,7 +278,7 @@ def test_check_jsonl_missing_field_in_conversation(tmp_path: Path): | |
|
|
||
| report = check_file(file) | ||
| assert not report["is_check_passed"] | ||
| assert "Field `content` is missing for a turn" in report["message"] | ||
| assert "Missing required column `content`" in report["message"] | ||
|
|
||
|
|
||
| def test_check_jsonl_wrong_turn_type(tmp_path: Path): | ||
|
|
@@ -277,7 +298,7 @@ def test_check_jsonl_wrong_turn_type(tmp_path: Path): | |
| report = check_file(file) | ||
| assert not report["is_check_passed"] | ||
| assert ( | ||
| "Invalid format on line 1 of the input file. Expected a dictionary" | ||
| "Invalid format on line 1 of the input file. The `messages` column must be a list of dicts." | ||
| in report["message"] | ||
| ) | ||
|
|
||
|
|
@@ -301,9 +322,7 @@ def test_check_jsonl_empty_messages(tmp_path: Path): | |
|
|
||
| report = check_file(file) | ||
| assert not report["is_check_passed"] | ||
| assert ( | ||
| "Expected a non-empty list of messages. Found empty list" in report["message"] | ||
| ) | ||
| assert "The `messages` column must not be empty" in report["message"] | ||
|
|
||
|
|
||
| def test_check_jsonl_valid_weights_all_messages(tmp_path: Path): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Bug: Inconsistent Return Type in `_check_message_role`
The `_check_message_role` function's return type annotation is `str | bool`, but it consistently returns `message["role"]`, which is always a string.