New options for preference tuning: rpo alpha, logprobs normalization, reference-free, simpo gamma #327
pyproject.toml

@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
 [tool.poetry]
 name = "together"
-version = "1.5.13"
+version = "1.5.14"
 authors = ["Together AI <[email protected]>"]
 description = "Python client for Together's Cloud Platform!"
 readme = "README.md"
@@ -72,6 +72,9 @@ def create_finetune_request(
     train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
+    dpo_normalize_logratios_by_length: bool = False,
+    rpo_alpha: float | None = None,
+    simpo_gamma: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
     if model is not None and from_checkpoint is not None:
@@ -182,6 +185,21 @@ def create_finetune_request(
     if dpo_beta is not None and training_method != "dpo":
         raise ValueError("dpo_beta is only supported for DPO training")
+    if dpo_normalize_logratios_by_length and training_method != "dpo":
+        raise ValueError(
+            "dpo_normalize_logratios_by_length=True is only supported for DPO training"
+        )
+    if rpo_alpha is not None:
+        if training_method != "dpo":
+            raise ValueError("rpo_alpha is only supported for DPO training")
+        if not rpo_alpha >= 0.0:
Contributor: Maybe it's wise to put an upper limit too.

Contributor (Author): Not sure what the limit could be here. Let's say 10? Wdyt?

Member: I'm not sure we should be enforcing any particular limit on this value, although it might be helpful. The problem is that this limit would apply only when users submit jobs via together-python.
+            raise ValueError(f"rpo_alpha should be non-negative (got {rpo_alpha})")
+
+    if simpo_gamma is not None:
+        if training_method != "dpo":
+            raise ValueError("simpo_gamma is only supported for DPO training")
+        if not simpo_gamma >= 0.0:
+            raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")
 
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
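A small aside on the guard style (an observation, not something stated in the PR): writing "not x >= 0.0" rather than "x < 0.0" also rejects NaN, because any comparison involving NaN evaluates to False. A minimal sketch:

import math

def is_rejected(value: float) -> bool:
    # Mirrors the guard in the diff above: anything not provably non-negative,
    # including NaN, is rejected.
    return not value >= 0.0

for value in (-1.0, 0.0, 2.5, math.nan):
    print(value, "rejected" if is_rejected(value) else "accepted")
# -1.0 rejected, 0.0 accepted, 2.5 accepted, nan rejected

Whether or not that was the intent here, it makes the validation slightly stricter than a plain less-than check.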
@@ -204,7 +222,24 @@ def create_finetune_request(
     if training_method == "sft":
         training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
     elif training_method == "dpo":
-        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+        if simpo_gamma is not None and simpo_gamma > 0:
Member: By the way, should we raise a ValueError if it's <= 0?

Contributor (Author): Added, and also added for rpo_alpha (can't imagine a use case for negative values for these parameters).
+            dpo_reference_free = True
+            dpo_normalize_logratios_by_length = True
+            rprint(
+                f"Parameter simpo_gamma was set to {simpo_gamma}. "
+                "SimPO training detected. Reference logits will not be used "
+                "and length normalization of log-probabilities will be enabled."
+            )
+        else:
+            dpo_reference_free = False
+
+        training_method_cls = TrainingMethodDPO(
+            dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            dpo_reference_free=dpo_reference_free,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
+        )
 
     finetune_request = FinetuneRequest(
         model=model,
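For readers unfamiliar with how these knobs interact, here is a rough sketch of a DPO-style preference loss that uses the same parameter names. It follows the DPO, RPO, and SimPO papers rather than Together's server-side implementation (which is not part of this client PR), so every formula below is an assumption for illustration only.

import torch
import torch.nn.functional as F

def preference_loss_sketch(
    policy_chosen_logps: torch.Tensor,    # summed log-probs of chosen responses, shape (B,)
    policy_rejected_logps: torch.Tensor,  # summed log-probs of rejected responses, shape (B,)
    ref_chosen_logps: torch.Tensor,       # same quantities under the frozen reference model
    ref_rejected_logps: torch.Tensor,
    chosen_lengths: torch.Tensor,         # response lengths in tokens, shape (B,)
    rejected_lengths: torch.Tensor,
    dpo_beta: float = 0.1,
    dpo_normalize_logratios_by_length: bool = False,
    dpo_reference_free: bool = False,
    rpo_alpha: float | None = None,
    simpo_gamma: float | None = None,
) -> torch.Tensor:
    chosen, rejected = policy_chosen_logps, policy_rejected_logps
    ref_chosen, ref_rejected = ref_chosen_logps, ref_rejected_logps
    if dpo_normalize_logratios_by_length:
        # Length normalization: average instead of summed log-probs, as in SimPO.
        chosen, rejected = chosen / chosen_lengths, rejected / rejected_lengths
        ref_chosen, ref_rejected = ref_chosen / chosen_lengths, ref_rejected / rejected_lengths
    margin = chosen - rejected
    if not dpo_reference_free:
        # Standard DPO subtracts the reference model's log-ratio.
        margin = margin - (ref_chosen - ref_rejected)
    logits = dpo_beta * margin
    if simpo_gamma is not None:
        # SimPO introduces a target reward margin gamma.
        logits = logits - simpo_gamma
    loss = -F.logsigmoid(logits).mean()
    if rpo_alpha is not None:
        # RPO adds an alpha-weighted NLL term on the chosen responses.
        loss = loss + rpo_alpha * (-policy_chosen_logps).mean()
    return loss

The client-side logic in this diff does the switching for the user: when simpo_gamma > 0 it sets dpo_reference_free=True and dpo_normalize_logratios_by_length=True, so a SimPO run corresponds to calling a loss like the sketch above with both flags enabled.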
@@ -302,6 +337,9 @@ def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -353,6 +391,9 @@ def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
            dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool): Whether or not to normalize log-ratios by sample length. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter for DPO training; weights the NLL term added to the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -405,6 +446,9 @@ def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
             from_checkpoint=from_checkpoint,
         )
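A hedged usage sketch of the new options from the caller's side. It assumes the resource is exposed as client.fine_tuning.create, as in existing together releases; the file ID, model name, and numeric values below are placeholders, not recommendations.

from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment

job = client.fine_tuning.create(
    training_file="file-your-preference-dataset-id",  # placeholder uploaded dataset ID
    model="your-base-model-name",                     # placeholder base model
    training_method="dpo",
    dpo_beta=0.1,
    dpo_normalize_logratios_by_length=True,  # new in this PR
    rpo_alpha=0.5,                           # new: weight of the extra NLL term
    # simpo_gamma=0.5,                       # new: a positive value switches to SimPO
)
print(job)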
@@ -714,6 +758,9 @@ async def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -765,6 +812,9 @@ async def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool): Whether or not to normalize log-ratios by sample length. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter for DPO training; weights the NLL term added to the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -817,6 +867,9 @@ async def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
             from_checkpoint=from_checkpoint,
         )
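The async resource mirrors the synchronous one. A brief sketch, again assuming the client exposes AsyncTogether with the same fine_tuning.create signature and using the same placeholders as above:

import asyncio
from together import AsyncTogether

async def main() -> None:
    client = AsyncTogether()  # reads TOGETHER_API_KEY from the environment
    job = await client.fine_tuning.create(
        training_file="file-your-preference-dataset-id",  # placeholder
        model="your-base-model-name",                     # placeholder
        training_method="dpo",
        simpo_gamma=0.5,  # positive value triggers the SimPO path described above
    )
    print(job)

asyncio.run(main())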
Review comment: This could simply be "if rpo_alpha".

Reply: Not quite, PEP 8 explicitly advises against it.

Reply: A bit below I want to notify the user that rpo_alpha == 0.0 throws an error.
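For context on that exchange, a minimal illustration (not from the PR) of why "if rpo_alpha" and "if rpo_alpha is not None" behave differently, which is exactly the case PEP 8 warns about:

def check(rpo_alpha):
    truthy = bool(rpo_alpha)          # what "if rpo_alpha" tests
    explicit = rpo_alpha is not None  # what "if rpo_alpha is not None" tests
    return truthy, explicit

print(check(None))  # (False, False): both styles skip the validation block
print(check(0.0))   # (False, True): the truthy style silently skips rpo_alpha=0.0
print(check(0.5))   # (True, True)

The explicit check lets the code treat rpo_alpha=0.0 as a deliberate user choice instead of silently ignoring it.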