Skip to content

Commit 17ea03f

Browse files
committed
roboflow CLI with train
1 parent 1d11997 commit 17ea03f

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

roboflow/roboflowpy.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ def login(args):
1919
roboflow.login(force=args.force)
2020

2121

22+
def train(args):
23+
rf = roboflow.Roboflow()
24+
workspace = rf.workspace(args.workspace) if args.workspace else rf.workspace()
25+
project = workspace.project(args.project)
26+
version = project.version(args.version_number)
27+
model = version.train(modelType=args.modelType, checkpoint=args.checkpoint)
28+
print(model)
29+
30+
2231
def _parse_url(url):
2332
regex = r"(?:https?://)?(?:universe|app)\.roboflow\.(?:com|one)/([^/]+)/([^/]+)(?:/dataset)?(?:/(\d+))?|([^/]+)/([^/]+)(?:/(\d+))?" # noqa: E501
2433
match = re.match(regex, url)
@@ -198,6 +207,7 @@ def _argparser():
198207
subparsers = parser.add_subparsers(title="subcommands")
199208
_add_login_parser(subparsers)
200209
_add_download_parser(subparsers)
210+
_add_train_parser(subparsers)
201211
_add_upload_parser(subparsers)
202212
_add_import_parser(subparsers)
203213
_add_infer_parser(subparsers)
@@ -310,6 +320,37 @@ def _add_upload_parser(subparsers):
310320
upload_parser.set_defaults(func=upload_image)
311321

312322

323+
def _add_train_parser(subparsers):
324+
train_parser = subparsers.add_parser("train", help="Train a model for a dataset version")
325+
train_parser.add_argument(
326+
"-w",
327+
dest="workspace",
328+
help="specify a workspace url or id (will use default workspace if not specified)",
329+
)
330+
train_parser.add_argument(
331+
"-p",
332+
dest="project",
333+
help="project_id to train the model for",
334+
)
335+
train_parser.add_argument(
336+
"-v",
337+
dest="version_number",
338+
type=int,
339+
help="version number to train",
340+
)
341+
train_parser.add_argument(
342+
"-t",
343+
dest="modelType",
344+
help="type of the model to train (e.g., rfdetr-nano, yolov8n)",
345+
)
346+
train_parser.add_argument(
347+
"--checkpoint",
348+
dest="checkpoint",
349+
help="checkpoint to resume training from",
350+
)
351+
train_parser.set_defaults(func=train)
352+
353+
313354
def _add_import_parser(subparsers):
314355
import_parser = subparsers.add_parser("import", help="Import a dataset from a local folder")
315356
import_parser.add_argument(

0 commit comments

Comments
 (0)