@@ -2,7 +2,7 @@

 import re
 from pathlib import Path
-from typing import List, Dict, Literal
+from typing import Dict, List, Literal

 from rich import print as rprint

@@ -18,10 +18,11 @@
     FinetuneList,
     FinetuneListEvents,
     FinetuneLRScheduler,
-    FinetuneRequest,
-    FinetuneResponse,
+    FinetuneMultimodalParams,
     FinetunePriceEstimationRequest,
     FinetunePriceEstimationResponse,
+    FinetuneRequest,
+    FinetuneResponse,
     FinetuneTrainingLimits,
     FullTrainingType,
     LinearLRScheduler,
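The diff never shows the definition of `FinetuneMultimodalParams`; from its usage further down it is constructed with a single keyword argument. A hypothetical reconstruction of its shape, assuming the `together.types` module's usual Pydantic conventions (the field name comes from the diff, everything else is an assumption):

```python
from pydantic import BaseModel


class FinetuneMultimodalParams(BaseModel):
    """Hypothetical reconstruction: the diff only ever constructs this type
    with a single keyword, so one field is assumed here."""

    train_vision: bool = False
```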
@@ -73,6 +74,7 @@ def create_finetune_request(
     lora_dropout: float | None = 0,
     lora_alpha: float | None = None,
     lora_trainable_modules: str | None = "all-linear",
+    train_vision: bool = False,
     suffix: str | None = None,
     wandb_api_key: str | None = None,
     wandb_base_url: str | None = None,
@@ -252,6 +254,15 @@ def create_finetune_request(
         simpo_gamma=simpo_gamma,
     )

+    if model_limits.supports_vision:
+        multimodal_params = FinetuneMultimodalParams(train_vision=train_vision)
+    elif not model_limits.supports_vision and train_vision:
+        raise ValueError(
+            f"Vision encoder training is not supported for the non-multimodal model `{model}`"
+        )
+    else:
+        multimodal_params = None
+
     finetune_request = FinetuneRequest(
         model=model,
         training_file=training_file,
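The new guard has three outcomes: vision-capable models always get a `FinetuneMultimodalParams` payload (even with the default `train_vision=False`), text-only models asked to train a vision encoder fail fast with a `ValueError`, and text-only models otherwise send no multimodal payload at all. A minimal standalone sketch of the same decision table, using a hypothetical `Limits` stand-in for `model_limits`:

```python
from dataclasses import dataclass


@dataclass
class Limits:  # hypothetical stand-in for FinetuneTrainingLimits
    supports_vision: bool


def resolve_multimodal(limits: Limits, train_vision: bool) -> dict | None:
    # Mirrors the branch order above: vision models always get a payload,
    # text-only models reject train_vision=True, everything else sends None.
    if limits.supports_vision:
        return {"train_vision": train_vision}
    elif train_vision:
        raise ValueError("Vision encoder training is not supported")
    return None


assert resolve_multimodal(Limits(True), False) == {"train_vision": False}
assert resolve_multimodal(Limits(False), False) is None
```

Note that in the diff, the `elif`'s `not model_limits.supports_vision` test is already implied by the failed `if`, which is why the sketch shortens it to `elif train_vision:` without changing behavior.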
@@ -272,6 +283,7 @@ def create_finetune_request(
         wandb_project_name=wandb_project_name,
         wandb_name=wandb_name,
         training_method=training_method_cls,
+        multimodal_params=multimodal_params,
         from_checkpoint=from_checkpoint,
         from_hf_model=from_hf_model,
         hf_model_revision=hf_model_revision,
@@ -342,6 +354,7 @@ def create(
         lora_dropout: float | None = 0,
         lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
+        train_vision: bool = False,
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         wandb_base_url: str | None = None,
@@ -387,6 +400,7 @@ def create(
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
             lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
             lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
+            train_vision (bool, optional): Whether to train the vision encoder in multimodal models. Defaults to False.
             suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                 Defaults to None.
             wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -464,6 +478,7 @@ def create(
             lora_dropout=lora_dropout,
             lora_alpha=lora_alpha,
             lora_trainable_modules=lora_trainable_modules,
+            train_vision=train_vision,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
             wandb_base_url=wandb_base_url,
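For callers, the new flag threads straight through `FineTuning.create` into the request. A hedged usage sketch, assuming this diff's signature; the model ID and file ID below are illustrative placeholders, not real values:

```python
from together import Together

client = Together()  # reads TOGETHER_API_KEY from the environment

job = client.fine_tuning.create(
    training_file="file-xxxxxxxx",           # placeholder file ID
    model="example-org/example-vl-model",    # placeholder multimodal model
    train_vision=True,  # new flag: also update the vision encoder weights
)
print(job.id)
```

With a text-only model, the same call would raise the `ValueError` added above before any request is sent.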
@@ -906,6 +921,7 @@ async def create(
         lora_dropout: float | None = 0,
         lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
+        train_vision: bool = False,
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         wandb_base_url: str | None = None,
@@ -951,6 +967,7 @@ async def create(
             lora_dropout (float, optional): Dropout rate for LoRA adapters. Defaults to 0.
             lora_alpha (float, optional): Alpha for LoRA adapters. Defaults to 8.
             lora_trainable_modules (str, optional): Trainable modules for LoRA adapters. Defaults to "all-linear".
+            train_vision (bool, optional): Whether to train the vision encoder in multimodal models. Defaults to False.
             suffix (str, optional): Up to 40 character suffix that will be added to your fine-tuned model name.
                 Defaults to None.
             wandb_api_key (str, optional): API key for Weights & Biases integration.
@@ -1028,6 +1045,7 @@ async def create(
             lora_dropout=lora_dropout,
             lora_alpha=lora_alpha,
             lora_trainable_modules=lora_trainable_modules,
+            train_vision=train_vision,
             suffix=suffix,
             wandb_api_key=wandb_api_key,
             wandb_base_url=wandb_base_url,
@@ -1046,7 +1064,11 @@ async def create(
             hf_output_repo_name=hf_output_repo_name,
         )

-        if from_checkpoint is None and from_hf_model is None:
+        if (
+            from_checkpoint is None
+            and from_hf_model is None
+            and not model_limits.supports_vision
+        ):
            price_estimation_result = await self.estimate_price(
                training_file=training_file,
                validation_file=validation_file,
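The widened condition means price estimation is now skipped in three cases rather than two. A hedged restatement of the guard as a standalone predicate (names mirror the locals above):

```python
def should_estimate_price(
    from_checkpoint: str | None,
    from_hf_model: str | None,
    supports_vision: bool,
) -> bool:
    # Estimation only runs for fresh jobs on text-only models; jobs resumed
    # from a checkpoint, imported from Hugging Face, or targeting a
    # vision-capable model skip it.
    return from_checkpoint is None and from_hf_model is None and not supports_vision


assert should_estimate_price(None, None, supports_vision=False)
assert not should_estimate_price(None, None, supports_vision=True)
```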