Merged
42 changes: 38 additions & 4 deletions src/together/cli/api/finetune.py
@@ -17,6 +17,8 @@
DownloadCheckpointType,
FinetuneEventType,
FinetuneTrainingLimits,
FullTrainingType,
LoRATrainingType,
)
from together.utils import (
finetune_price_to_dollars,
@@ -29,13 +31,21 @@

_CONFIRMATION_MESSAGE = (
"You are about to create a fine-tuning job. "
"The cost of your job will be determined by the model size, the number of tokens "
"The estimated price of this job is {price}. "
"The actual cost of your job will be determined by the model size, the number of tokens "
"in the training file, the number of tokens in the validation file, the number of epochs, and "
"the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
"the number of evaluations. Visit https://www.together.ai/pricing to learn more about pricing.\n"
"{warning}"
"You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
"Do you want to proceed?"
)

_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
"The estimated price of this job is significantly greater than your current credit limit and balance combined. "
"It will likely fail due to insufficient funds. "
"Please consider increasing your credit limit at https://api.together.xyz/settings/profile\n"
)


class DownloadCheckpointTypeChoice(click.Choice):
def __init__(self) -> None:
@@ -357,12 +367,36 @@ def create(
"You have specified a number of evaluation loops but no validation file."
)

if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
finetune_price_estimation_result = client.fine_tuning.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model,
n_epochs=n_epochs,
n_evals=n_evals,
training_type="lora" if lora else "full",
training_method=training_method,
)

price = click.style(
f"${finetune_price_estimation_result.estimated_total_price:.2f}",
bold=True,
)

if not finetune_price_estimation_result.allowed_to_proceed:
warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
else:
warning = ""

confirmation_message = _CONFIRMATION_MESSAGE.format(
price=price,
warning=warning,
)

if confirm or click.confirm(confirmation_message, default=True, show_default=True):
response = client.fine_tuning.create(
**training_args,
verbose=True,
)

report_string = f"Successfully submitted a fine-tuning job {response.id}"
if response.created_at is not None:
created_time = datetime.strptime(
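To make the new CLI flow easier to follow, here is a small self-contained sketch of how the confirmation prompt is assembled from the price estimate. The template strings are abbreviated and the helper name is illustrative, not part of the diff.

import click

# Abbreviated versions of the templates added in this diff.
_CONFIRMATION_MESSAGE = (
    "You are about to create a fine-tuning job. "
    "The estimated price of this job is {price}.\n"
    "{warning}"
    "Do you want to proceed?"
)
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
    "The estimated price of this job is significantly greater than your "
    "current credit limit and balance combined.\n"
)


def build_confirmation_prompt(estimated_total_price: float, allowed_to_proceed: bool) -> str:
    # Bold price; red warning only when the estimate exceeds the user's limit.
    price = click.style(f"${estimated_total_price:.2f}", bold=True)
    warning = (
        click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
        if not allowed_to_proceed
        else ""
    )
    return _CONFIRMATION_MESSAGE.format(price=price, warning=warning)


print(build_confirmation_prompt(12.34, allowed_to_proceed=False))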
159 changes: 158 additions & 1 deletion src/together/resources/finetune.py
@@ -20,6 +20,8 @@
FinetuneLRScheduler,
FinetuneRequest,
FinetuneResponse,
FinetunePriceEstimationRequest,
FinetunePriceEstimationResponse,
FinetuneTrainingLimits,
FullTrainingType,
LinearLRScheduler,
@@ -31,7 +33,7 @@
TrainingMethodSFT,
TrainingType,
)
from together.types.finetune import DownloadCheckpointType
from together.types.finetune import DownloadCheckpointType, TrainingMethod
from together.utils import log_warn_once, normalize_key


@@ -42,6 +44,12 @@
TrainingMethodSFT().method,
TrainingMethodDPO().method,
}
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
"The estimated price of the fine-tuning job is {} which is significantly "
"greater than your current credit limit and balance combined. "
"It will likely fail due to insufficient funds. "
"Please proceed at your own risk."
Member (suggested change):
-    "It will likely fail due to insufficient funds. "
-    "Please proceed at your own risk."
+    "It will likely get cancelled due to insufficient funds. "
+    "Proceed at your own risk."

)


def create_finetune_request(
@@ -474,11 +482,29 @@ def create(
hf_output_repo_name=hf_output_repo_name,
)

price_estimation_result = self.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model_name,
n_epochs=n_epochs,
n_evals=n_evals,
training_type="lora" if lora else "full",
training_method=training_method,
)

if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
finetune_request,
)
if not price_estimation_result.allowed_to_proceed:
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price
)
+ "[/red]",
)
parameter_payload = finetune_request.model_dump(exclude_none=True)

response, _, _ = requestor.request(
Expand All @@ -493,6 +519,75 @@ def create(

return FinetuneResponse(**response.data)

def estimate_price(
self,
*,
training_file: str,
model: str | None,
Contributor: Can these fields be none?

Member: They can in principle, but the price estimation API doesn't seem to support it. Therefore, we need to fix the API, disable calling the estimation when model is None, but set this field to always be defined (i.e., remove the None option).

validation_file: str | None = None,
n_epochs: int | None = None,
n_evals: int | None = None,
training_type: str = "lora",
training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
"""
Estimates the price of a fine-tuning job

Args:
training_file (str): ID of the uploaded training file.
model (str, optional): Name of the base model to fine-tune.
validation_file (str, optional): ID of the uploaded validation file. Defaults to None.
n_epochs (int, optional): Number of training epochs. Defaults to None.
n_evals (int, optional): Number of evaluation loops on the validation file. Defaults to None.
training_type (str, optional): "lora" or "full". Defaults to "lora".
training_method (str, optional): "sft" or "dpo". Defaults to "sft".
Member: The docstring is not consistent with the actual arguments.


Returns:
FinetunePriceEstimationResponse: Object containing the estimated price.
"""
training_type_cls: TrainingType | None = None
Contributor: We don't need to set them to None here, as they're defined below in all branches (or exception). You can keep the type definition if you want to.

Contributor (author): mypy complains.

Contributor: That's a bit weird, all the possible branches are covered below.

Member: Does it complain when you don't define the type or when you don't set it to None?

training_method_cls: TrainingMethod | None = None

if training_method == "sft":
training_method_cls = TrainingMethodSFT(method="sft")
elif training_method == "dpo":
training_method_cls = TrainingMethodDPO(method="dpo")
else:
raise ValueError(f"Unknown training method: {training_method}")

if training_type.lower() == "lora":
# parameters of lora are unused in price estimation
# but we need to set them to valid values
training_type_cls = LoRATrainingType(
type="Lora",
lora_r=16,
lora_alpha=16,
lora_dropout=0.0,
lora_trainable_modules="all-linear",
)
elif training_type.lower() == "full":
training_type_cls = FullTrainingType(type="Full")
else:
raise ValueError(f"Unknown training type: {training_type}")

request = FinetunePriceEstimationRequest(
training_file=training_file,
validation_file=validation_file,
model=model,
n_epochs=n_epochs,
n_evals=n_evals,
training_type=training_type_cls,
training_method=training_method_cls,
)
parameter_payload = request.model_dump(exclude_none=True)
requestor = api_requestor.APIRequestor(
client=self._client,
)

response, _, _ = requestor.request(
options=TogetherRequest(
method="POST", url="fine-tunes/estimate-price", params=parameter_payload
),
stream=False,
)
assert isinstance(response, TogetherResponse)

return FinetunePriceEstimationResponse(**response.data)

def list(self) -> FinetuneList:
"""
Lists fine-tune job history
@@ -941,11 +1036,29 @@ async def create(
hf_output_repo_name=hf_output_repo_name,
)

price_estimation_result = await self.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model_name,
n_epochs=n_epochs,
n_evals=n_evals,
training_type=finetune_request.training_type,
training_method=finetune_request.training_method,
)

if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
finetune_request,
)
if not price_estimation_result.allowed_to_proceed:
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price
)
+ "[/red]",
)
parameter_payload = finetune_request.model_dump(exclude_none=True)

response, _, _ = await requestor.arequest(
Expand All @@ -961,6 +1074,50 @@ async def create(

return FinetuneResponse(**response.data)

async def estimate_price(
self,
*,
training_file: str,
model: str,
validation_file: str | None = None,
n_epochs: int | None = None,
n_evals: int | None = None,
training_type: TrainingType | None = None,
training_method: TrainingMethodSFT | TrainingMethodDPO | None = None,
) -> FinetunePriceEstimationResponse:
"""
Async method to estimate the price of a fine-tuning job

Args:
training_file (str): ID of the uploaded training file.
model (str): Name of the base model to fine-tune.
validation_file (str, optional): ID of the uploaded validation file. Defaults to None.
n_epochs (int, optional): Number of training epochs. Defaults to None.
n_evals (int, optional): Number of evaluation loops on the validation file. Defaults to None.
training_type (TrainingType, optional): Training type object (LoRA or full). Defaults to None.
training_method (TrainingMethodSFT | TrainingMethodDPO, optional): Training method object. Defaults to None.
Member: Same here (the docstring is not consistent with the actual arguments).


Returns:
FinetunePriceEstimationResponse: Object containing the estimated price.
"""
request = FinetunePriceEstimationRequest(
training_file=training_file,
validation_file=validation_file,
model=model,
n_epochs=n_epochs,
n_evals=n_evals,
training_type=training_type,
training_method=training_method,
)
parameter_payload = request.model_dump(exclude_none=True)
requestor = api_requestor.APIRequestor(
client=self._client,
)

response, _, _ = await requestor.arequest(
options=TogetherRequest(
method="POST", url="fine-tunes/estimate-price", params=parameter_payload
),
stream=False,
)
assert isinstance(response, TogetherResponse)

return FinetunePriceEstimationResponse(**response.data)

async def list(self) -> FinetuneList:
"""
Async method to list fine-tune job history
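For readers of this PR, a minimal usage sketch of the new synchronous estimate_price method follows. The client setup mirrors the standard together SDK entry point; the file ID and model name are placeholders rather than values from this diff, and the same call is available as an awaitable on the async client.

from together import Together

client = Together()  # expects TOGETHER_API_KEY in the environment

# Placeholder file ID and model name for illustration only.
estimate = client.fine_tuning.estimate_price(
    training_file="file-abc123",
    model="example-org/example-model",
    n_epochs=3,
    n_evals=0,
    training_type="lora",
    training_method="sft",
)

print(f"Estimated total price: ${estimate.estimated_total_price:.2f}")
print(f"Estimated training tokens: {estimate.estimated_train_token_count}")

if not estimate.allowed_to_proceed:
    # The estimate exceeds the credit limit and balance combined, so the job
    # would likely fail due to insufficient funds.
    print("Warning: insufficient funds for this job.")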
4 changes: 4 additions & 0 deletions src/together/types/__init__.py
@@ -54,6 +54,8 @@
FinetuneListEvents,
FinetuneRequest,
FinetuneResponse,
FinetunePriceEstimationRequest,
FinetunePriceEstimationResponse,
FinetuneDeleteResponse,
FinetuneTrainingLimits,
FullTrainingType,
@@ -103,6 +105,8 @@
"FinetuneDeleteResponse",
"FinetuneDownloadResult",
"FinetuneLRScheduler",
"FinetunePriceEstimationRequest",
"FinetunePriceEstimationResponse",
"LinearLRScheduler",
"LinearLRSchedulerArgs",
"CosineLRScheduler",
26 changes: 26 additions & 0 deletions src/together/types/finetune.py
@@ -308,6 +308,32 @@ def validate_training_type(cls, v: TrainingType) -> TrainingType:
raise ValueError("Unknown training type")


class FinetunePriceEstimationRequest(BaseModel):
"""
Fine-tune price estimation request type
"""

training_file: str
validation_file: str | None = None
model: str
n_epochs: int | None = None
n_evals: int | None = None
training_type: TrainingType | None = None
Contributor: Can it be None? Same goes for n_epochs and n_evals?

Contributor (author): It can be None. If I implement it in any other way, it would be an unrelated change -- I'd just need to change it for all other data types for FT.

training_method: TrainingMethodSFT | TrainingMethodDPO


class FinetunePriceEstimationResponse(BaseModel):
"""
Fine-tune price estimation response type
"""

estimated_total_price: float
user_limit: float
estimated_train_token_count: int
estimated_eval_token_count: int
allowed_to_proceed: bool


class FinetuneList(BaseModel):
# object type
object: Literal["list"] | None = None
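As a side note on the review discussion about optional fields: the resource code serializes the request with model_dump(exclude_none=True), so any optional field left as None is simply omitted from the payload. A small sketch under that assumption, with placeholder identifiers; the import path follows where these models are defined in this diff.

from together.types.finetune import (
    FinetunePriceEstimationRequest,
    TrainingMethodSFT,
)

# Placeholder identifiers for illustration only.
request = FinetunePriceEstimationRequest(
    training_file="file-abc123",
    model="example-org/example-model",
    training_method=TrainingMethodSFT(method="sft"),
)

# validation_file, n_epochs, n_evals and training_type were left as None,
# so exclude_none drops them from the serialized payload entirely.
payload = request.model_dump(exclude_none=True)
print(sorted(payload.keys()))  # ['model', 'training_file', 'training_method']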