Skip to content

Commit f927635

Browse files
committed
Add model_type parameter support for training
- Add model_type parameter to CLI train command - Update version.py to handle model_type in train method - Snake_case naming convention for model_type parameter - Simplify CLI workspace handling - Maintain API payload order minimal
1 parent 36a55fc commit f927635

File tree

3 files changed

+81
-2
lines changed

3 files changed

+81
-2
lines changed

roboflow/core/version.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,13 @@ def export(self, model_format=None):
290290
except json.JSONDecodeError:
291291
response.raise_for_status()
292292

293-
def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
293+
def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
294294
"""
295295
Ask the Roboflow API to train a previously exported version's dataset.
296296
297297
Args:
298298
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
299+
model_type: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending an invalid parameter in this argument.
299300
checkpoint: A string representing the checkpoint to use while training
300301
plot: Whether to plot the training results. Default is `False`.
301302
@@ -328,12 +329,17 @@ def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> Inferenc
328329
url = f"{API_URL}/{workspace}/{project}/{self.version}/train"
329330

330331
data = {}
332+
331333
if speed:
332334
data["speed"] = speed
333335

334336
if checkpoint:
335337
data["checkpoint"] = checkpoint
336338

339+
if model_type:
340+
# API expects camelCase key
341+
data["modelType"] = model_type
342+
337343
write_line("Reaching out to Roboflow to start training...")
338344

339345
response = requests.post(url, json=data, params={"api_key": self.__api_key})

roboflow/roboflowpy.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ def login(args):
1919
roboflow.login(force=args.force)
2020

2121

22+
def train(args):
23+
rf = roboflow.Roboflow()
24+
workspace = rf.workspace(args.workspace) # handles None internally
25+
project = workspace.project(args.project)
26+
version = project.version(args.version_number)
27+
model = version.train(model_type=args.model_type, checkpoint=args.checkpoint)
28+
print(model)
29+
30+
2231
def _parse_url(url):
2332
regex = r"(?:https?://)?(?:universe|app)\.roboflow\.(?:com|one)/([^/]+)/([^/]+)(?:/dataset)?(?:/(\d+))?|([^/]+)/([^/]+)(?:/(\d+))?" # noqa: E501
2433
match = re.match(regex, url)
@@ -198,6 +207,7 @@ def _argparser():
198207
subparsers = parser.add_subparsers(title="subcommands")
199208
_add_login_parser(subparsers)
200209
_add_download_parser(subparsers)
210+
_add_train_parser(subparsers)
201211
_add_upload_parser(subparsers)
202212
_add_import_parser(subparsers)
203213
_add_infer_parser(subparsers)
@@ -310,6 +320,37 @@ def _add_upload_parser(subparsers):
310320
upload_parser.set_defaults(func=upload_image)
311321

312322

323+
def _add_train_parser(subparsers):
324+
train_parser = subparsers.add_parser("train", help="Train a model for a dataset version")
325+
train_parser.add_argument(
326+
"-w",
327+
dest="workspace",
328+
help="specify a workspace url or id (will use default workspace if not specified)",
329+
)
330+
train_parser.add_argument(
331+
"-p",
332+
dest="project",
333+
help="project_id to train the model for",
334+
)
335+
train_parser.add_argument(
336+
"-v",
337+
dest="version_number",
338+
type=int,
339+
help="version number to train",
340+
)
341+
train_parser.add_argument(
342+
"-t",
343+
dest="model_type",
344+
help="type of the model to train (e.g., rfdetr-nano, yolov8n)",
345+
)
346+
train_parser.add_argument(
347+
"--checkpoint",
348+
dest="checkpoint",
349+
help="checkpoint to resume training from",
350+
)
351+
train_parser.set_defaults(func=train)
352+
353+
313354
def _add_import_parser(subparsers):
314355
import_parser = subparsers.add_parser("import", help="Import a dataset from a local folder")
315356
import_parser.add_argument(

tests/manual/debugme.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55
os.environ["ROBOFLOW_CONFIG_DIR"] = f"{thisdir}/data/.config"
66

77
from roboflow.roboflowpy import _argparser # noqa: E402
8+
from roboflow import Roboflow
89

910
# import requests
1011
# requests.urllib3.disable_warnings()
1112

1213
rootdir = os.path.abspath(f"{thisdir}/../..")
1314
sys.path.append(rootdir)
1415

15-
if __name__ == "__main__":
16+
17+
def run_cli():
1618
parser = _argparser()
1719
# args = parser.parse_args(["login"])
1820
# args = parser.parse_args(f"upload {thisdir}/../datasets/chess -w wolfodorpythontests -p chess".split()) # noqa: E501 // docs
@@ -45,3 +47,33 @@
4547
# f"import -w tonyprivate -p meh-plvrv {thisdir}/../datasets/paligemma/".split() # noqa: E501 // docs
4648
)
4749
args.func(args)
50+
51+
52+
def run_api_train():
53+
rf = Roboflow()
54+
project = rf.workspace("model-evaluation-workspace").project("penguin-finder")
55+
# version_number = project.generate_version(
56+
# settings={
57+
# "augmentation": {
58+
# "bbblur": {"pixels": 1.5},
59+
# "image": {"versions": 2},
60+
# },
61+
# "preprocessing": {
62+
# "auto-orient": True,
63+
# },
64+
# }
65+
# )
66+
version_number = "18"
67+
print(version_number)
68+
version = project.version(version_number)
69+
model = version.train(
70+
speed="fast", # Options: "fast" (default) or "accurate" (paid feature)
71+
checkpoint=None, # Use a specific checkpoint to continue training
72+
modelType="rfdetr-nano",
73+
)
74+
print(model)
75+
76+
77+
if __name__ == "__main__":
78+
# run_cli()
79+
run_api_train()

0 commit comments

Comments
 (0)