Add support for the Training Method for finetuning, and for Direct Preference Optimization (DPO) #262
Conversation
src/together/utils/files.py
Outdated
filtered_messages.append(
    {column: message[column] for column in REQUIRED_COLUMNS_MESSAGE}
)
Hm, I'm not sure if filtering files when they are uploaded is the right solution: this will require users to reupload their data whenever we support a new field for messages (for example, function calling)
Agree, removed the filtering part from the function
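A minimal sketch of the "validate, don't filter" approach settled on above: unknown keys (for example, a future function-calling field) are left untouched, so users would not need to re-upload data when new fields are supported. The `REQUIRED_COLUMNS_MESSAGE` value here is an assumption standing in for the SDK's own constant.

```python
from typing import Any, Dict, List

# Assumed stand-in for the SDK's constant of the same name.
REQUIRED_COLUMNS_MESSAGE = ["role", "content"]


def check_messages(messages: List[Dict[str, Any]]) -> None:
    """Validate required fields without dropping any unknown ones."""
    for message in messages:
        missing = [c for c in REQUIRED_COLUMNS_MESSAGE if c not in message]
        if missing:
            raise ValueError(f"Message is missing required fields: {missing}")
```

Because the function only inspects messages and never rebuilds them, extra fields pass through unchanged.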
src/together/utils/files.py
Outdated
)

if not isinstance(example["preferred_output"], list):
    raise ValueError(
All of these should be InvalidFileFormatError
Fixed
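For reference, a sketch of what the requested change looks like: a dedicated exception type lets callers distinguish malformed training files from generic `ValueError`s. The exception class below is a stand-in for the SDK's actual `InvalidFileFormatError`.

```python
from typing import Any, Dict


class InvalidFileFormatError(Exception):
    """Stand-in for the SDK's error raised when a training file fails validation."""


def validate_preference_example(example: Dict[str, Any], idx: int) -> None:
    # Raise the file-format error instead of a bare ValueError.
    if not isinstance(example.get("preferred_output"), list):
        raise InvalidFileFormatError(
            f"Field 'preferred_output' must be a list (example {idx})"
        )
```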
Co-authored-by: Max Ryabinin <[email protected]>
src/together/types/finetune.py
Outdated
Training method type for SFT training
"""

method: str = "sft"
Suggested change:
-    method: str = "sft"
+    method: Literal["sft"] = "sft"
Added
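A sketch of the `Literal`-typed field suggested above, using plain dataclasses for illustration: a static type checker will reject `TrainingMethodSFT(method="dpo")`, while a plain `str` annotation would not. The class bodies here are simplified stand-ins for the SDK's models.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass
class TrainingMethodSFT:
    # Literal pins the field to exactly "sft" for type checkers.
    method: Literal["sft"] = "sft"


@dataclass
class TrainingMethodDPO:
    method: Literal["dpo"] = "dpo"
```

Note that `Literal` is enforced by static analysis (mypy, pyright), not at runtime; a runtime-validating model (e.g. pydantic) would additionally reject bad values when constructed.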
src/together/utils/files.py
Outdated
has_weights = False
# Check for weights in messages
if _has_weights(messages):
Why did you make this into a separate function? Why not inline it?
It can even be like:

    has_weights = any("weight" in message for message in messages)
Fixed
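The suggested one-liner in action, with illustrative sample data: `any()` short-circuits on the first message that carries a `"weight"` key, so no helper function is needed.

```python
# Illustrative messages; "weight" on the assistant turn is optional per-message metadata.
messages = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": "hello", "weight": 1},
]

# Inline check: True as soon as any message carries a "weight" key.
has_weights = any("weight" in message for message in messages)
```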
src/together/utils/files.py
Outdated
)
previous_role = message["role"]

return messages, has_weights
Why do you need to return messages? The row doesn't seem to be modified.
fixed
src/together/utils/files.py
Outdated
return messages, has_weights


def validate_preference_openai(example: Dict[str, Any], idx: int = 0) -> Dict[str, Any]:
Why do you need to return an example?
fixed
src/together/resources/finetune.py
Outdated
lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

training_method_cls: Union[TrainingMethodSFT, TrainingMethodDPO] = (
Nit: since you're using the | notation to specify union types above, I would use it here as well and remove the redundant import
Ok
src/together/utils/files.py
Outdated
has_weights = False
# Check for weights in messages
if _has_weights(messages):
    has_weights = True
Isn't it just the following? :)
Suggested change:
-    has_weights = False
-    # Check for weights in messages
-    if _has_weights(messages):
-        has_weights = True
+    has_weights = _has_weights(messages)
src/together/utils/files.py
Outdated
def validate_messages(
    messages: List[Dict[str, str | bool]], idx: int = 0
It's hard to imagine a case where we would want to use the default line number, maybe it's best to remove the default value?
Fixed
src/together/utils/files.py
Outdated
example["input"]["messages"], _ = validate_messages(
    example["input"]["messages"], idx
)
We don't modify anything in messages, I would simply make validate_messages return nothing and raise an exception in case of an error
Fixed
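A sketch of a validator in the shape agreed on above: it returns nothing, raises on the first error, and callers invoke it purely for the side effect. The exception class and the alternating-role check (implied by the `previous_role` bookkeeping in the diff) are assumptions about the surrounding code.

```python
from typing import Any, Dict, List


class InvalidFileFormatError(Exception):
    """Stand-in for the SDK's file-format error."""


def validate_messages(messages: List[Dict[str, Any]], idx: int) -> None:
    """Raise InvalidFileFormatError if messages are malformed; return nothing."""
    previous_role = None
    for message in messages:
        if "role" not in message or "content" not in message:
            raise InvalidFileFormatError(
                f"Message missing 'role' or 'content' (example {idx})"
            )
        # Assumed rule: consecutive messages should not repeat the same role.
        if message["role"] == previous_role:
            raise InvalidFileFormatError(
                f"Consecutive messages share role {message['role']!r} (example {idx})"
            )
        previous_role = message["role"]
```

Call sites then shrink from `example["input"]["messages"], _ = validate_messages(...)` to a bare `validate_messages(example["input"]["messages"], idx)`.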
tests/unit/test_files_checks.py
Outdated
def test_check_jsonl_invalid_preference_openai_structural_issues(tmp_path: Path):
    # Test various structural issues in OpenAI preference format
    test_cases = [
Let's use pytest.mark.parametrize for iterating over multiple test cases
Done
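A sketch of the `pytest.mark.parametrize` pattern requested above: each structural issue becomes its own test case with a readable id, so one failing payload does not hide the others. The payloads and the `is_structurally_valid` helper are illustrative stand-ins, not the PR's actual validator.

```python
import pytest


def is_structurally_valid(example):
    # Minimal stand-in validator for the sketch: both fields must be lists.
    messages = example.get("input", {}).get("messages")
    preferred = example.get("preferred_output", [])
    return isinstance(messages, list) and isinstance(preferred, list)


@pytest.mark.parametrize(
    "example, description",
    [
        ({"input": {}}, "messages are missing"),
        (
            {"input": {"messages": []}, "preferred_output": "x"},
            "preferred_output is not a list",
        ),
    ],
    ids=["missing-messages", "preferred-output-not-list"],
)
def test_structural_issue(example, description):
    assert not is_structurally_valid(example), f"Should fail when {description}"
```

Running `pytest -v` then reports each case separately, e.g. `test_structural_issue[missing-messages]`.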
def test_check_jsonl_valid_conversational_multiple_turns(tmp_path: Path):
    # Create a valid JSONL file with conversational format and multiple user-assistant turn pairs
    file = tmp_path / "valid_conversational_multiple_turns.jsonl"
    content = [
I'd prefer to keep the current file for this test and write a new one for , because
- Unit tests should test orthogonal capabilities; otherwise, this gets misleading when an error is introduced (improper parsing of preference data should not affect tests for regular conversation datasets)
- Right now, it actually looks like this test is identical to test_check_jsonl_valid_preference_openai, which is unlikely to be what you want :)
Created a separate file
src/together/resources/finetune.py
Outdated
AVAILABLE_TRAINING_METHODS = {
    TrainingMethodSFT().method,
    TrainingMethodDPO().method,
}
Since this is a constant, can you move it to the top of the file (outside of the function and the class definition)?
Fixed
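A sketch of the hoisted constant at module level, with simplified stand-in classes: deriving the set from the classes' own defaults means it cannot drift out of sync with them if a method name ever changes.

```python
from dataclasses import dataclass


@dataclass
class TrainingMethodSFT:
    method: str = "sft"


@dataclass
class TrainingMethodDPO:
    method: str = "dpo"


# Module-level constant, derived from the method defaults above.
AVAILABLE_TRAINING_METHODS = {
    TrainingMethodSFT().method,
    TrainingMethodDPO().method,
}
```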
lr_scheduler_args=FinetuneLinearLRSchedulerArgs(min_lr_ratio=min_lr_ratio),
)

training_method_cls: TrainingMethodSFT | TrainingMethodDPO = TrainingMethodSFT()
Nit: maybe annotate the type as training_method_cls: TrainingMethod? It's a bit clearer and more extensible
There were some issues with pre-commit checks when I tried to do this, as I remember
Weird, do you remember what the error was, by any chance? Not blocking, but I'd love to know how to fix it in the future
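For reference, a sketch of the base-class annotation suggested in this thread, with simplified stand-in classes: both concrete methods share a common parent, so the variable can be annotated once and stays extensible when new training methods are added.

```python
from dataclasses import dataclass


@dataclass
class TrainingMethod:
    """Assumed common base for all training methods."""

    method: str = ""


@dataclass
class TrainingMethodSFT(TrainingMethod):
    method: str = "sft"


@dataclass
class TrainingMethodDPO(TrainingMethod):
    method: str = "dpo"


# One annotation covers every current and future subclass.
training_method_cls: TrainingMethod = TrainingMethodSFT()
```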
tests/unit/test_preference_openai.py
Outdated
assert report["has_min_samples"]


# Define test cases for missing fields
The comment seems redundant
removed
tests/unit/test_preference_openai.py
Outdated
from together.constants import MIN_SAMPLES
from together.utils.files import check_file

# Test data for preference OpenAI format
This one's also not very informative given the name of the variable
tests/unit/test_preference_openai.py
Outdated
assert not report["is_check_passed"], f"Test should fail when {description}"


# Define test cases for structural issues
Here as well
fixed
assert not report["is_check_passed"], f"Test should fail when {description}"


STRUCTURAL_ISSUE_TEST_CASES = [
Nit: the constant can be made private
assert report["has_min_samples"]


MISSING_FIELDS_TEST_CASES = [
Nit: the constant can be made private
Describe your changes
This PR adds support for the Training Method for finetuning, and for Direct Preference Optimization (DPO).