diff --git a/roboflow/core/version.py b/roboflow/core/version.py index f6ca0fc7..7e63b103 100644 --- a/roboflow/core/version.py +++ b/roboflow/core/version.py @@ -290,12 +290,13 @@ def export(self, model_format=None): except json.JSONDecodeError: response.raise_for_status() - def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel: + def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel: """ Ask the Roboflow API to train a previously exported version's dataset. Args: speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`. + model_type: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending an invalid parameter in this argument. checkpoint: A string representing the checkpoint to use while training plot: Whether to plot the training results. Default is `False`. @@ -328,12 +329,17 @@ def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> Inferenc url = f"{API_URL}/{workspace}/{project}/{self.version}/train" data = {} + if speed: data["speed"] = speed if checkpoint: data["checkpoint"] = checkpoint + if model_type: + # API expects camelCase key + data["modelType"] = model_type + write_line("Reaching out to Roboflow to start training...") response = requests.post(url, json=data, params={"api_key": self.__api_key}) diff --git a/roboflow/roboflowpy.py b/roboflow/roboflowpy.py index 48d6c0d7..70cf6db9 100755 --- a/roboflow/roboflowpy.py +++ b/roboflow/roboflowpy.py @@ -19,6 +19,15 @@ def login(args): roboflow.login(force=args.force) +def train(args): + rf = roboflow.Roboflow() + workspace = rf.workspace(args.workspace) # handles None internally + project = workspace.project(args.project) + version = project.version(args.version_number) + model = version.train(model_type=args.model_type, checkpoint=args.checkpoint) + print(model) + + def _parse_url(url): regex = r"(?:https?://)?(?:universe|app)\.roboflow\.(?:com|one)/([^/]+)/([^/]+)(?:/dataset)?(?:/(\d+))?|([^/]+)/([^/]+)(?:/(\d+))?" # noqa: E501 match = re.match(regex, url) @@ -198,6 +207,7 @@ def _argparser(): subparsers = parser.add_subparsers(title="subcommands") _add_login_parser(subparsers) _add_download_parser(subparsers) + _add_train_parser(subparsers) _add_upload_parser(subparsers) _add_import_parser(subparsers) _add_infer_parser(subparsers) @@ -310,6 +320,37 @@ def _add_upload_parser(subparsers): upload_parser.set_defaults(func=upload_image) +def _add_train_parser(subparsers): + train_parser = subparsers.add_parser("train", help="Train a model for a dataset version") + train_parser.add_argument( + "-w", + dest="workspace", + help="specify a workspace url or id (will use default workspace if not specified)", + ) + train_parser.add_argument( + "-p", + dest="project", + help="project_id to train the model for", + ) + train_parser.add_argument( + "-v", + dest="version_number", + type=int, + help="version number to train", + ) + train_parser.add_argument( + "-t", + dest="model_type", + help="type of the model to train (e.g., rfdetr-nano, yolov8n)", + ) + train_parser.add_argument( + "--checkpoint", + dest="checkpoint", + help="checkpoint to resume training from", + ) + train_parser.set_defaults(func=train) + + def _add_import_parser(subparsers): import_parser = subparsers.add_parser("import", help="Import a dataset from a local folder") import_parser.add_argument(