Skip to content

Commit eba5e5f

Browse files
VLM Finetuning support (#411)
* Support Multimodal datasets * Support VLM finetuning
1 parent 137b84e commit eba5e5f

File tree

9 files changed

+402
-89
lines changed

9 files changed

+402
-89
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@ build-backend = "poetry.masonry.api"
1212

1313
[tool.poetry]
1414
name = "together"
15-
version = "1.5.33"
15+
version = "1.5.34"
1616
authors = ["Together AI <[email protected]>"]
1717
description = "Python client for Together's Cloud Platform! Note: SDK 2.0 is now available at https://github.com/togethercomputer/together-py"
1818
readme = "README.md"

src/together/cli/api/finetune.py

Lines changed: 13 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
11
from __future__ import annotations
22

33
import json
4-
import re
54
from datetime import datetime, timezone
65
from textwrap import wrap
76
from typing import Any, Literal
@@ -14,18 +13,11 @@
1413

1514
from together import Together
1615
from together.cli.api.utils import BOOL_WITH_AUTO, INT_WITH_MAX, generate_progress_bar
17-
from together.types.finetune import (
18-
DownloadCheckpointType,
19-
FinetuneEventType,
20-
FinetuneTrainingLimits,
21-
FullTrainingType,
22-
LoRATrainingType,
23-
)
16+
from together.types.finetune import DownloadCheckpointType, FinetuneTrainingLimits
2417
from together.utils import (
2518
finetune_price_to_dollars,
2619
format_timestamp,
2720
log_warn,
28-
log_warn_once,
2921
parse_timestamp,
3022
)
3123

@@ -203,6 +195,12 @@ def fine_tuning(ctx: click.Context) -> None:
203195
help="Whether to mask the user messages in conversational data or prompts in instruction data. "
204196
"`auto` will automatically determine whether to mask the inputs based on the data format.",
205197
)
198+
@click.option(
199+
"--train-vision",
200+
type=bool,
201+
default=False,
202+
help="Whether to train the vision encoder. Only supported for multimodal models.",
203+
)
206204
@click.option(
207205
"--from-checkpoint",
208206
type=str,
@@ -258,6 +256,7 @@ def create(
258256
lora_dropout: float,
259257
lora_alpha: float,
260258
lora_trainable_modules: str,
259+
train_vision: bool,
261260
suffix: str,
262261
wandb_api_key: str,
263262
wandb_base_url: str,
@@ -299,6 +298,7 @@ def create(
299298
lora_dropout=lora_dropout,
300299
lora_alpha=lora_alpha,
301300
lora_trainable_modules=lora_trainable_modules,
301+
train_vision=train_vision,
302302
suffix=suffix,
303303
wandb_api_key=wandb_api_key,
304304
wandb_base_url=wandb_base_url,
@@ -368,6 +368,10 @@ def create(
368368
"You have specified a number of evaluation loops but no validation file."
369369
)
370370

371+
if model_limits.supports_vision:
372+
# Don't show price estimation for multimodal models yet
373+
confirm = True
374+
371375
finetune_price_estimation_result = client.fine_tuning.estimate_price(
372376
training_file=training_file,
373377
validation_file=validation_file,

src/together/constants.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
import enum
22

3+
34
# Session constants
45
TIMEOUT_SECS = 600
56
MAX_SESSION_LIFETIME_SECS = 180
@@ -40,6 +41,11 @@
4041
# the number of bytes in a gigabyte, used to convert bytes to GB for readable comparison
4142
NUM_BYTES_IN_GB = 2**30
4243

44+
# Multimodal limits
45+
MAX_IMAGES_PER_EXAMPLE = 10
46+
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB
47+
# Max length = Header length + base64 factor (4/3) * image bytes
48+
MAX_BASE64_IMAGE_LENGTH = len("data:image/jpeg;base64,") + 4 * MAX_IMAGE_BYTES // 3
4349

4450
# expected columns for Parquet files
4551
PARQUET_EXPECTED_COLUMNS = ["input_ids", "attention_mask", "labels"]

src/together/resources/finetune.py

Lines changed: 26 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from pathlib import Path
5-
from typing import List, Dict, Literal
5+
from typing import Dict, List, Literal
66

77
from rich import print as rprint
88

@@ -18,10 +18,11 @@
1818
FinetuneList,
1919
FinetuneListEvents,
2020
FinetuneLRScheduler,
21-
FinetuneRequest,
22-
FinetuneResponse,
21+
FinetuneMultimodalParams,
2322
FinetunePriceEstimationRequest,
2423
FinetunePriceEstimationResponse,
24+
FinetuneRequest,
25+
FinetuneResponse,
2526
FinetuneTrainingLimits,
2627
FullTrainingType,
2728
LinearLRScheduler,
@@ -73,6 +74,7 @@ def create_finetune_request(
7374
lora_dropout: float | None = 0,
7475
lora_alpha: float | None = None,
7576
lora_trainable_modules: str | None = "all-linear",
77+
train_vision: bool = False,
7678
suffix: str | None = None,
7779
wandb_api_key: str | None = None,
7880
wandb_base_url: str | None = None,
@@ -252,6 +254,15 @@ def create_finetune_request(
252254
simpo_gamma=simpo_gamma,
253255
)
254256

257+
if model_limits.supports_vision:
258+
multimodal_params = FinetuneMultimodalParams(train_vision=train_vision)
259+
elif not model_limits.supports_vision and train_vision:
260+
raise ValueError(
261+
f"Vision encoder training is not supported for the non-multimodal model `{model}`"
262+
)
263+
else:
264+
multimodal_params = None
265+
255266
finetune_request = FinetuneRequest(
256267
model=model,
257268
training_file=training_file,
@@ -272,6 +283,7 @@ def create_finetune_request(
272283
wandb_project_name=wandb_project_name,
273284
wandb_name=wandb_name,
274285
training_method=training_method_cls,
286+
multimodal_params=multimodal_params,
275287
from_checkpoint=from_checkpoint,
276288
from_hf_model=from_hf_model,
277289
hf_model_revision=hf_model_revision,
@@ -342,6 +354,7 @@ def create(
342354
lora_dropout: float | None = 0,
343355
lora_alpha: float | None = None,
344356
lora_trainable_modules: str | None = "all-linear",
357+
train_vision: bool = False,
345358
suffix: str | None = None,
346359
wandb_api_key: str | None = None,
347360
wandb_base_url: str | None = None,
@@ -387,6 +400,7 @@ def create(
387400
lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
388401
lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
389402
lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
403+
train_vision (bool, optional): Whether to train vision encoder in multimodal models. Defaults to False.
390404
suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
391405
Defaults to None.
392406
wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -464,6 +478,7 @@ def create(
464478
lora_dropout=lora_dropout,
465479
lora_alpha=lora_alpha,
466480
lora_trainable_modules=lora_trainable_modules,
481+
train_vision=train_vision,
467482
suffix=suffix,
468483
wandb_api_key=wandb_api_key,
469484
wandb_base_url=wandb_base_url,
@@ -906,6 +921,7 @@ async def create(
906921
lora_dropout: float | None = 0,
907922
lora_alpha: float | None = None,
908923
lora_trainable_modules: str | None = "all-linear",
924+
train_vision: bool = False,
909925
suffix: str | None = None,
910926
wandb_api_key: str | None = None,
911927
wandb_base_url: str | None = None,
@@ -951,6 +967,7 @@ async def create(
951967
lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
952968
lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
953969
lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
970+
train_vision (bool, optional): Whether to train vision encoder in multimodal models. Defaults to False.
954971
suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
955972
Defaults to None.
956973
wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -1028,6 +1045,7 @@ async def create(
10281045
lora_dropout=lora_dropout,
10291046
lora_alpha=lora_alpha,
10301047
lora_trainable_modules=lora_trainable_modules,
1048+
train_vision=train_vision,
10311049
suffix=suffix,
10321050
wandb_api_key=wandb_api_key,
10331051
wandb_base_url=wandb_base_url,
@@ -1046,7 +1064,11 @@ async def create(
10461064
hf_output_repo_name=hf_output_repo_name,
10471065
)
10481066

1049-
if from_checkpoint is None and from_hf_model is None:
1067+
if (
1068+
from_checkpoint is None
1069+
and from_hf_model is None
1070+
and not model_limits.supports_vision
1071+
):
10501072
price_estimation_result = await self.estimate_price(
10511073
training_file=training_file,
10521074
validation_file=validation_file,

src/together/types/__init__.py

Lines changed: 29 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -7,17 +7,18 @@
77
AudioSpeechStreamChunk,
88
AudioSpeechStreamEvent,
99
AudioSpeechStreamResponse,
10+
AudioTimestampGranularities,
1011
AudioTranscriptionRequest,
11-
AudioTranslationRequest,
1212
AudioTranscriptionResponse,
13+
AudioTranscriptionResponseFormat,
1314
AudioTranscriptionVerboseResponse,
15+
AudioTranslationRequest,
1416
AudioTranslationResponse,
1517
AudioTranslationVerboseResponse,
16-
AudioTranscriptionResponseFormat,
17-
AudioTimestampGranularities,
1818
ModelVoices,
1919
VoiceListResponse,
2020
)
21+
from together.types.batch import BatchEndpoint, BatchJob, BatchJobStatus
2122
from together.types.chat_completions import (
2223
ChatCompletionChunk,
2324
ChatCompletionRequest,
@@ -31,6 +32,19 @@
3132
)
3233
from together.types.embeddings import EmbeddingRequest, EmbeddingResponse
3334
from together.types.endpoints import Autoscaling, DedicatedEndpoint, ListEndpoint
35+
from together.types.evaluation import (
36+
ClassifyParameters,
37+
CompareParameters,
38+
EvaluationCreateResponse,
39+
EvaluationJob,
40+
EvaluationRequest,
41+
EvaluationStatus,
42+
EvaluationStatusResponse,
43+
EvaluationType,
44+
JudgeModelConfig,
45+
ModelRequest,
46+
ScoreParameters,
47+
)
3448
from together.types.files import (
3549
FileDeleteResponse,
3650
FileList,
@@ -41,49 +55,32 @@
4155
FileType,
4256
)
4357
from together.types.finetune import (
44-
TrainingMethodDPO,
45-
TrainingMethodSFT,
46-
FinetuneCheckpoint,
4758
CosineLRScheduler,
4859
CosineLRSchedulerArgs,
60+
FinetuneCheckpoint,
61+
FinetuneDeleteResponse,
4962
FinetuneDownloadResult,
50-
LinearLRScheduler,
51-
LinearLRSchedulerArgs,
52-
FinetuneLRScheduler,
5363
FinetuneList,
5464
FinetuneListEvents,
55-
FinetuneRequest,
56-
FinetuneResponse,
65+
FinetuneLRScheduler,
66+
FinetuneMultimodalParams,
5767
FinetunePriceEstimationRequest,
5868
FinetunePriceEstimationResponse,
59-
FinetuneDeleteResponse,
69+
FinetuneRequest,
70+
FinetuneResponse,
6071
FinetuneTrainingLimits,
6172
FullTrainingType,
73+
LinearLRScheduler,
74+
LinearLRSchedulerArgs,
6275
LoRATrainingType,
76+
TrainingMethodDPO,
77+
TrainingMethodSFT,
6378
TrainingType,
6479
)
6580
from together.types.images import ImageRequest, ImageResponse
6681
from together.types.models import ModelObject, ModelUploadRequest, ModelUploadResponse
6782
from together.types.rerank import RerankRequest, RerankResponse
68-
from together.types.batch import BatchJob, BatchJobStatus, BatchEndpoint
69-
from together.types.evaluation import (
70-
EvaluationType,
71-
EvaluationStatus,
72-
JudgeModelConfig,
73-
ModelRequest,
74-
ClassifyParameters,
75-
ScoreParameters,
76-
CompareParameters,
77-
EvaluationRequest,
78-
EvaluationCreateResponse,
79-
EvaluationJob,
80-
EvaluationStatusResponse,
81-
)
82-
from together.types.videos import (
83-
CreateVideoBody,
84-
CreateVideoResponse,
85-
VideoJob,
86-
)
83+
from together.types.videos import CreateVideoBody, CreateVideoResponse, VideoJob
8784

8885

8986
__all__ = [
@@ -131,6 +128,7 @@
131128
"RerankRequest",
132129
"RerankResponse",
133130
"FinetuneTrainingLimits",
131+
"FinetuneMultimodalParams",
134132
"AudioSpeechRequest",
135133
"AudioResponseFormat",
136134
"AudioLanguage",

src/together/types/finetune.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,12 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from typing import List, Literal, Any
4+
from typing import Any, List, Literal
55

66
from pydantic import Field, StrictBool, field_validator
77

88
from together.types.abstract import BaseModel
9-
from together.types.common import (
10-
ObjectType,
11-
)
9+
from together.types.common import ObjectType
1210

1311

1412
class FinetuneJobStatus(str, Enum):
@@ -175,6 +173,14 @@ class TrainingMethodDPO(TrainingMethod):
175173
simpo_gamma: float | None = None
176174

177175

176+
class FinetuneMultimodalParams(BaseModel):
177+
"""
178+
Multimodal parameters
179+
"""
180+
181+
train_vision: bool = False
182+
183+
178184
class FinetuneProgress(BaseModel):
179185
"""
180186
Fine-tune job progress
@@ -231,6 +237,8 @@ class FinetuneRequest(BaseModel):
231237
)
232238
# from step
233239
from_checkpoint: str | None = None
240+
# multimodal parameters
241+
multimodal_params: FinetuneMultimodalParams | None = None
234242
# hf related fields
235243
hf_api_token: str | None = None
236244
hf_output_repo_name: str | None = None
@@ -313,6 +321,8 @@ class FinetuneResponse(BaseModel):
313321
training_file_size: int | None = Field(None, alias="TrainingFileSize")
314322
train_on_inputs: StrictBool | Literal["auto"] | None = "auto"
315323
from_checkpoint: str | None = None
324+
# multimodal parameters
325+
multimodal_params: FinetuneMultimodalParams | None = None
316326

317327
progress: FinetuneProgress | None = None
318328

@@ -409,6 +419,7 @@ class FinetuneTrainingLimits(BaseModel):
409419
min_learning_rate: float
410420
full_training: FinetuneFullTrainingLimits | None = None
411421
lora_training: FinetuneLoraTrainingLimits | None = None
422+
supports_vision: bool = False
412423

413424

414425
class LinearLRSchedulerArgs(BaseModel):

0 commit comments

Comments (0)