Skip to content

Rodrigo/modeltype parameter #401

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,12 +290,13 @@ def export(self, model_format=None):
except json.JSONDecodeError:
response.raise_for_status()

def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
"""
Ask the Roboflow API to train a previously exported version's dataset.

Args:
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
model_type: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending an invalid parameter in this argument.
checkpoint: A string representing the checkpoint to use while training
plot: Whether to plot the training results. Default is `False`.

Expand Down Expand Up @@ -328,12 +329,17 @@ def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> Inferenc
url = f"{API_URL}/{workspace}/{project}/{self.version}/train"

data = {}

if speed:
data["speed"] = speed

if checkpoint:
data["checkpoint"] = checkpoint

if model_type:
# API expects camelCase key
data["modelType"] = model_type

write_line("Reaching out to Roboflow to start training...")

response = requests.post(url, json=data, params={"api_key": self.__api_key})
Expand Down
41 changes: 41 additions & 0 deletions roboflow/roboflowpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ def login(args):
roboflow.login(force=args.force)


def train(args):
rf = roboflow.Roboflow()
workspace = rf.workspace(args.workspace) # handles None internally
project = workspace.project(args.project)
version = project.version(args.version_number)
model = version.train(model_type=args.model_type, checkpoint=args.checkpoint)
print(model)


def _parse_url(url):
regex = r"(?:https?://)?(?:universe|app)\.roboflow\.(?:com|one)/([^/]+)/([^/]+)(?:/dataset)?(?:/(\d+))?|([^/]+)/([^/]+)(?:/(\d+))?" # noqa: E501
match = re.match(regex, url)
Expand Down Expand Up @@ -198,6 +207,7 @@ def _argparser():
subparsers = parser.add_subparsers(title="subcommands")
_add_login_parser(subparsers)
_add_download_parser(subparsers)
_add_train_parser(subparsers)
_add_upload_parser(subparsers)
_add_import_parser(subparsers)
_add_infer_parser(subparsers)
Expand Down Expand Up @@ -310,6 +320,37 @@ def _add_upload_parser(subparsers):
upload_parser.set_defaults(func=upload_image)


def _add_train_parser(subparsers):
train_parser = subparsers.add_parser("train", help="Train a model for a dataset version")
train_parser.add_argument(
"-w",
dest="workspace",
help="specify a workspace url or id (will use default workspace if not specified)",
)
train_parser.add_argument(
"-p",
dest="project",
help="project_id to train the model for",
)
train_parser.add_argument(
"-v",
dest="version_number",
type=int,
help="version number to train",
)
train_parser.add_argument(
"-t",
dest="model_type",
help="type of the model to train (e.g., rfdetr-nano, yolov8n)",
)
train_parser.add_argument(
"--checkpoint",
dest="checkpoint",
help="checkpoint to resume training from",
)
train_parser.set_defaults(func=train)


def _add_import_parser(subparsers):
import_parser = subparsers.add_parser("import", help="Import a dataset from a local folder")
import_parser.add_argument(
Expand Down
Loading