Skip to content

Commit 1d11997

Browse files
committed
Train with modelType parameter
1 parent 4d185a5 commit 1d11997

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

roboflow/core/version.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,13 @@ def export(self, model_format=None):
290290
except json.JSONDecodeError:
291291
response.raise_for_status()
292292

293-
def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
293+
def train(self, speed=None, modelType=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
294294
"""
295295
Ask the Roboflow API to train a previously exported version's dataset.
296296
297297
Args:
298298
speed: Whether to train quickly or accurately. Note: accurate training is a paid feature. Default speed is `fast`.
299+
modelType: The type of model to train. Default depends on kind of project. It takes precedence over speed. You can check the list of model ids by sending a invalid parameter in this argument.
299300
checkpoint: A string representing the checkpoint to use while training
300301
plot: Whether to plot the training results. Default is `False`.
301302
@@ -328,12 +329,15 @@ def train(self, speed=None, checkpoint=None, plot_in_notebook=False) -> Inferenc
328329
url = f"{API_URL}/{workspace}/{project}/{self.version}/train"
329330

330331
data = {}
331-
if speed:
332-
data["speed"] = speed
332+
if modelType:
333+
data["modelType"] = modelType
333334

334335
if checkpoint:
335336
data["checkpoint"] = checkpoint
336337

338+
if speed:
339+
data["speed"] = speed
340+
337341
write_line("Reaching out to Roboflow to start training...")
338342

339343
response = requests.post(url, json=data, params={"api_key": self.__api_key})

tests/manual/debugme.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def run_cli():
5151

5252
def run_api_train():
5353
rf = Roboflow()
54-
project = rf.workspace("meh3").project("mosquitobao")
54+
project = rf.workspace("model-evaluation-workspace").project("donut-2-lcfx0")
5555
# version_number = project.generate_version(
5656
# settings={
5757
# "augmentation": {
@@ -63,7 +63,7 @@ def run_api_train():
6363
# },
6464
# }
6565
# )
66-
version_number = "61"
66+
version_number = "4"
6767
print(version_number)
6868
version = project.version(version_number)
6969
model = version.train(

0 commit comments

Comments
 (0)