Skip to content

Commit a13046c

Browse files
authored
Merge pull request #402 from roboflow/cursor/address-pr-comments-and-fix-1a1d
Address pr comments and fix
2 parents 10b93d7 + 847ced2 commit a13046c

File tree

3 files changed

+11
-88
lines changed

3 files changed

+11
-88
lines changed

roboflow/core/version.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,13 @@ def export(self, model_format=None):
290290
except json.JSONDecodeError:
291291
response.raise_for_status()
292292

293-
def train(self, speed=None, modelType=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
293+
def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
294294
"""
295295
Ask the Roboflow API to train a previously exported version's dataset.
296296
297297
Args:
298298
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
299-
modelType: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending a invalid parameter in this argument.
299+
model_type: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending an invalid parameter in this argument.
300300
checkpoint: A string representing the checkpoint to use while training
301301
plot: Whether to plot the training results. Default is `False`.
302302
@@ -329,14 +329,16 @@ def train(self, speed=None, modelType=None, checkpoint=None, plot_in_notebook=Fa
329329
url = f"{API_URL}/{workspace}/{project}/{self.version}/train"
330330

331331
data = {}
332-
if modelType:
333-
data["modelType"] = modelType
332+
# Keep existing order to avoid unnecessary diffs; append new field without reordering
333+
if speed:
334+
data["speed"] = speed
334335

335336
if checkpoint:
336337
data["checkpoint"] = checkpoint
337338

338-
if speed:
339-
data["speed"] = speed
339+
if model_type:
340+
# API expects camelCase key
341+
data["modelType"] = model_type
340342

341343
write_line("Reaching out to Roboflow to start training...")
342344

roboflow/roboflowpy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@ def login(args):
2121

2222
def train(args):
2323
rf = roboflow.Roboflow()
24-
workspace = rf.workspace(args.workspace) if args.workspace else rf.workspace()
24+
workspace = rf.workspace(args.workspace) # handles None internally
2525
project = workspace.project(args.project)
2626
version = project.version(args.version_number)
27-
model = version.train(modelType=args.modelType, checkpoint=args.checkpoint)
27+
model = version.train(model_type=args.model_type, checkpoint=args.checkpoint)
2828
print(model)
2929

3030

@@ -340,7 +340,7 @@ def _add_train_parser(subparsers):
340340
)
341341
train_parser.add_argument(
342342
"-t",
343-
dest="modelType",
343+
dest="model_type",
344344
help="type of the model to train (e.g., rfdetr-nano, yolov8n)",
345345
)
346346
train_parser.add_argument(

tests/manual/debugme.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

0 commit comments

Comments
 (0)