Skip to content

Commit f561f05

Browse files
Fix train-over-api devx (#408)
* Fix train-over-api devx * fix(pre_commit): 🎨 auto format pre-commit hooks * lint gods, please forgive me * all hail the linting gods * fix version tests * fix(pre_commit): 🎨 auto format pre-commit hooks --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6a6025a commit f561f05

File tree

4 files changed

+215
-134
lines changed

4 files changed

+215
-134
lines changed

roboflow/adapters/rfapi.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,106 @@ def get_project(api_key, workspace_url, project_url):
4949
return result
5050

5151

52+
def start_version_training(
    api_key: str,
    workspace_url: str,
    project_url: str,
    version: str,
    *,
    speed: Optional[str] = None,
    checkpoint: Optional[str] = None,
    model_type: Optional[str] = None,
):
    """
    Ask the backend to start training a specific version.

    Posts to the version's ``/train`` endpoint. Only the optional settings
    that are not ``None`` are included in the JSON body.

    Args:
        api_key: Roboflow API key, sent as a query parameter.
        workspace_url: Workspace segment of the endpoint path.
        project_url: Project segment of the endpoint path.
        version: Version segment of the endpoint path.
        speed: Optional value sent under the ``speed`` key.
        checkpoint: Optional value sent under the ``checkpoint`` key.
        model_type: Optional value sent under the ``modelType`` key.

    Returns:
        True when ``response.ok`` is truthy.

    Raises:
        RoboflowError: When the response is not OK; carries the response text.
    """
    endpoint = f"{API_URL}/{workspace_url}/{project_url}/{version}/train?api_key={api_key}&nocache=true"

    # The API expects camelCase for the model type key; the other keys
    # match the argument names as-is.
    optional_fields = (
        ("speed", speed),
        ("checkpoint", checkpoint),
        ("modelType", model_type),
    )
    body = {key: value for key, value in optional_fields if value is not None}

    result = requests.post(endpoint, json=body)
    if result.ok:
        return True
    raise RoboflowError(result.text)
82+
83+
84+
def get_version(api_key: str, workspace_url: str, project_url: str, version: str, nocache: bool = False):
    """
    Retrieve the details of one dataset version from the API.

    Args:
        api_key: Roboflow API key
        workspace_url: Workspace slug/url
        project_url: Project slug/url
        version: Version identifier (number or slug)
        nocache: When truthy, ``&nocache=true`` is appended to the request URL

    Returns:
        The response body parsed as JSON.

    Raises:
        RoboflowError: When the status code is anything other than 200;
            carries the response text.
    """
    endpoint = f"{API_URL}/{workspace_url}/{project_url}/{version}?api_key={api_key}"
    # The cache-bypass flag is only added on request.
    endpoint = endpoint + ("&nocache=true" if nocache else "")

    reply = requests.get(endpoint)
    if reply.status_code == 200:
        return reply.json()
    raise RoboflowError(reply.text)
109+
110+
111+
def get_version_export(
    api_key: str,
    workspace_url: str,
    project_url: str,
    version: str,
    format: str,
):
    """
    Fetch export status or finalized link for a specific version/format.

    Args:
        api_key: Roboflow API key, sent as a query parameter.
        workspace_url: Workspace segment of the endpoint path.
        project_url: Project segment of the endpoint path.
        version: Version segment of the endpoint path.
        format: Export format segment of the endpoint path.

    Returns:
        - ``{"ready": False, "progress": float}`` when the export is still in
          progress (HTTP 202). ``progress`` falls back to 0.0 when the server
          omits it, sends a value that cannot be converted to float, or sends
          a JSON body that is not an object.
        - The JSON payload from the server, unchanged, when the export is
          ready (HTTP 200).

    Raises:
        RoboflowError: On any status other than 200/202 (carrying the response
            text), or when a 200/202 response body is not valid JSON.
    """
    url = f"{API_URL}/{workspace_url}/{project_url}/{version}/{format}?api_key={api_key}&nocache=true"
    response = requests.get(url)

    # Anything other than 200 (ready) or 202 (in progress) is an error.
    if response.status_code not in (200, 202):
        raise RoboflowError(response.text)

    try:
        payload = response.json()
    except ValueError as err:
        # Every JSON decode failure raised by requests derives from ValueError.
        # A 200/202 without parseable JSON is surfaced as an API error.
        raise RoboflowError(str(response)) from err

    if response.status_code == 202:
        # A non-object body (e.g. null or a list) has no .get(); treat it the
        # same as a missing progress value instead of leaking AttributeError.
        progress = payload.get("progress") if isinstance(payload, dict) else None
        try:
            progress_val = float(progress) if progress is not None else 0.0
        except (TypeError, ValueError, OverflowError):
            progress_val = 0.0
        return {"ready": False, "progress": progress_val}

    # 200 OK: export is ready; return payload unchanged
    return payload
150+
151+
52152
def upload_image(
53153
api_key,
54154
project_url,

roboflow/core/version.py

Lines changed: 77 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from dotenv import load_dotenv
1313
from tqdm import tqdm
1414

15+
from roboflow.adapters import rfapi
1516
from roboflow.config import (
1617
API_URL,
1718
APP_URL,
@@ -92,11 +93,11 @@ def __init__(
9293

9394
version_without_workspace = os.path.basename(str(version))
9495

95-
response = requests.get(f"{API_URL}/{workspace}/{project}/{self.version}?api_key={self.__api_key}")
96-
if response.ok:
97-
version_info = response.json()["version"]
96+
try:
97+
version_response = rfapi.get_version(self.__api_key, workspace, project, self.version)
98+
version_info = version_response.get("version", {})
9899
has_model = bool(version_info.get("train", {}).get("model"))
99-
else:
100+
except rfapi.RoboflowError:
100101
has_model = False
101102

102103
if not has_model:
@@ -152,16 +153,17 @@ def __init__(
152153

153154
def __check_if_generating(self):
154155
# check Roboflow API to see if this version is still generating
155-
156-
url = f"{API_URL}/{self.workspace}/{self.project}/{self.version}?nocache=true"
157-
response = requests.get(url, params={"api_key": self.__api_key})
158-
response.raise_for_status()
159-
if response.json()["version"]["progress"] is None:
160-
progress = 0.0
161-
else:
162-
progress = float(response.json()["version"]["progress"])
163-
164-
return response.json()["version"]["generating"], progress
156+
versiondict = rfapi.get_version(
157+
api_key=self.__api_key,
158+
workspace_url=self.workspace,
159+
project_url=self.project,
160+
version=self.version,
161+
nocache=True,
162+
)
163+
version_obj = versiondict.get("version", {})
164+
progress = 0.0 if version_obj.get("progress") is None else float(version_obj.get("progress"))
165+
generating = bool(version_obj.get("generating") or version_obj.get("images", 0) == 0)
166+
return generating, progress
165167

166168
def __wait_if_generating(self, recurse=False):
167169
# checks if a given version is still in the progress of generating
@@ -219,15 +221,22 @@ def download(self, model_format=None, location=None, overwrite: bool = False):
219221
if self.__api_key == "coco-128-sample":
220222
link = "https://app.roboflow.com/ds/n9QwXwUK42?key=NnVCe2yMxP"
221223
else:
222-
url = self.__get_download_url(model_format)
223-
response = requests.get(url, params={"api_key": self.__api_key})
224-
if response.status_code == 200:
225-
link = response.json()["export"]["link"]
226-
else:
227-
try:
228-
raise RuntimeError(response.json())
229-
except json.JSONDecodeError:
230-
response.raise_for_status()
224+
workspace, project, *_ = self.id.rsplit("/")
225+
try:
226+
export_info = rfapi.get_version_export(
227+
api_key=self.__api_key,
228+
workspace_url=workspace,
229+
project_url=project,
230+
version=self.version,
231+
format=model_format,
232+
)
233+
except rfapi.RoboflowError as e:
234+
raise RuntimeError(str(e))
235+
236+
if "ready" in export_info and export_info.get("ready") is False:
237+
raise RuntimeError(export_info)
238+
239+
link = export_info["export"]["link"]
231240

232241
self.__download_zip(link, location, model_format)
233242
self.__extract_zip(location, model_format)
@@ -256,39 +265,36 @@ def export(self, model_format=None):
256265

257266
self.__wait_if_generating()
258267

259-
url = self.__get_download_url(model_format)
260-
response = requests.get(url, params={"api_key": self.__api_key})
261-
if not response.ok:
262-
try:
263-
raise RuntimeError(response.json())
264-
except json.JSONDecodeError:
265-
response.raise_for_status()
266-
267-
# the rest api returns 202 if the export is still in progress
268-
if response.status_code == 202:
269-
status_code_check = 202
270-
while status_code_check == 202:
271-
time.sleep(1)
272-
response = requests.get(url, params={"api_key": self.__api_key})
273-
status_code_check = response.status_code
274-
if status_code_check == 202:
275-
progress = response.json()["progress"]
276-
progress_message = (
277-
"Exporting format " + model_format + " in progress : " + str(round(progress * 100, 2)) + "%"
278-
)
279-
sys.stdout.write("\r" + progress_message)
280-
sys.stdout.flush()
281-
282-
if response.status_code == 200:
268+
workspace, project, *_ = self.id.rsplit("/")
269+
export_info = rfapi.get_version_export(
270+
api_key=self.__api_key,
271+
workspace_url=workspace,
272+
project_url=project,
273+
version=self.version,
274+
format=model_format,
275+
)
276+
while "ready" in export_info and export_info.get("ready") is False:
277+
progress = export_info.get("progress", 0.0)
278+
progress_message = (
279+
"Exporting format " + model_format + " in progress : " + str(round(progress * 100, 2)) + "%"
280+
)
281+
sys.stdout.write("\r" + progress_message)
282+
sys.stdout.flush()
283+
time.sleep(1)
284+
export_info = rfapi.get_version_export(
285+
api_key=self.__api_key,
286+
workspace_url=workspace,
287+
project_url=project,
288+
version=self.version,
289+
format=model_format,
290+
)
291+
if "export" in export_info:
283292
sys.stdout.write("\n")
284293
print("\r" + "Version export complete for " + model_format + " format")
285294
sys.stdout.flush()
286295
return True
287296
else:
288-
try:
289-
raise RuntimeError(response.json())
290-
except json.JSONDecodeError:
291-
response.raise_for_status()
297+
raise RuntimeError(f"Unexpected export {export_info}")
292298

293299
def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=False) -> InferenceModel:
294300
"""
@@ -326,28 +332,22 @@ def train(self, speed=None, model_type=None, checkpoint=None, plot_in_notebook=F
326332
self.export(train_model_format)
327333

328334
workspace, project, *_ = self.id.rsplit("/")
329-
url = f"{API_URL}/{workspace}/{project}/{self.version}/train"
330335

331-
data = {}
332-
333-
if speed:
334-
data["speed"] = speed
335-
336-
if checkpoint:
337-
data["checkpoint"] = checkpoint
338-
339-
if model_type:
340-
# API expects camelCase key
341-
data["modelType"] = model_type
336+
payload_speed = speed if speed else None
337+
payload_checkpoint = checkpoint if checkpoint else None
338+
payload_model_type = model_type if model_type else None
342339

343340
write_line("Reaching out to Roboflow to start training...")
344341

345-
response = requests.post(url, json=data, params={"api_key": self.__api_key})
346-
if not response.ok:
347-
try:
348-
raise RuntimeError(response.json())
349-
except json.JSONDecodeError:
350-
response.raise_for_status()
342+
rfapi.start_version_training(
343+
api_key=self.__api_key,
344+
workspace_url=workspace,
345+
project_url=project,
346+
version=self.version,
347+
speed=payload_speed,
348+
checkpoint=payload_checkpoint,
349+
model_type=payload_model_type,
350+
)
351351

352352
status = "training"
353353

@@ -374,10 +374,14 @@ def live_plot(epochs, mAP, loss, title=""):
374374
num_machine_spin_dots = []
375375

376376
while status == "training" or status == "running":
377-
url = f"{API_URL}/{self.workspace}/{self.project}/{self.version}?nocache=true"
378-
response = requests.get(url, params={"api_key": self.__api_key})
379-
response.raise_for_status()
380-
version = response.json()["version"]
377+
version_response = rfapi.get_version(
378+
api_key=self.__api_key,
379+
workspace_url=self.workspace,
380+
project_url=self.project,
381+
version=self.version,
382+
nocache=True,
383+
)
384+
version = version_response.get("version", {})
381385
if "models" in version.keys():
382386
models = version["models"]
383387
else:

tests/manual/debugme.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,18 +52,18 @@ def run_cli():
5252
def run_api_train():
5353
rf = Roboflow()
5454
project = rf.workspace("meh3").project("mosquitobao")
55-
# version_number = project.generate_version(
56-
# settings={
57-
# "augmentation": {
58-
# "bbblur": {"pixels": 1.5},
59-
# "image": {"versions": 2},
60-
# },
61-
# "preprocessing": {
62-
# "auto-orient": True,
63-
# },
64-
# }
65-
# )
66-
version_number = "61"
55+
version_number = project.generate_version(
56+
settings={
57+
"augmentation": {
58+
"bbblur": {"pixels": 1.5},
59+
"image": {"versions": 2},
60+
},
61+
"preprocessing": {
62+
"auto-orient": True,
63+
},
64+
}
65+
)
66+
# version_number = "61"
6767
print(version_number)
6868
version = project.version(version_number)
6969
model = version.train(

0 commit comments

Comments
 (0)