Skip to content

Commit 40e2241

Browse files
authored
Merge pull request #368 from roboflow/sb/add-batch-methods
Add methods for retrieving batches for a project
2 parents c129fea + ef3ae71 commit 40e2241

File tree

2 files changed

+148
-0
lines changed

2 files changed

+148
-0
lines changed

roboflow/core/project.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -886,3 +886,68 @@ def create_annotation_job(
886886
raise RuntimeError(f"Failed to create annotation job: {response.text}")
887887

888888
return response.json()
889+
890+
def get_batches(self) -> Dict:
891+
"""
892+
Get a list of all batches in the project.
893+
894+
Returns:
895+
Dict: A dictionary containing the list of batches
896+
897+
Example:
898+
>>> import roboflow
899+
900+
>>> rf = roboflow.Roboflow(api_key="YOUR_API_KEY")
901+
902+
>>> project = rf.workspace().project("PROJECT_ID")
903+
904+
>>> batches = project.get_batches()
905+
"""
906+
url = f"{API_URL}/{self.__workspace}/{self.__project_name}/batches?api_key={self.__api_key}"
907+
908+
response = requests.get(url)
909+
910+
if response.status_code != 200:
911+
try:
912+
error_data = response.json()
913+
if "error" in error_data:
914+
raise RuntimeError(error_data["error"])
915+
raise RuntimeError(response.text)
916+
except ValueError:
917+
raise RuntimeError(f"Failed to get batches: {response.text}")
918+
919+
return response.json()
920+
921+
def get_batch(self, batch_id: str) -> Dict:
922+
"""
923+
Get information for a specific batch in the project.
924+
925+
Args:
926+
batch_id (str): The ID of the batch to retrieve
927+
928+
Returns:
929+
Dict: A dictionary containing the batch details
930+
931+
Example:
932+
>>> import roboflow
933+
934+
>>> rf = roboflow.Roboflow(api_key="YOUR_API_KEY")
935+
936+
>>> project = rf.workspace().project("PROJECT_ID")
937+
938+
>>> batch = project.get_batch("batch123")
939+
"""
940+
url = f"{API_URL}/{self.__workspace}/{self.__project_name}/batches/{batch_id}?api_key={self.__api_key}"
941+
942+
response = requests.get(url)
943+
944+
if response.status_code != 200:
945+
try:
946+
error_data = response.json()
947+
if "error" in error_data:
948+
raise RuntimeError(error_data["error"])
949+
raise RuntimeError(response.text)
950+
except ValueError:
951+
raise RuntimeError(f"Failed to get batch {batch_id}: {response.text}")
952+
953+
return response.json()

tests/test_project.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,89 @@ def test_project_upload_dataset(self):
455455
for mock in mocks.values():
456456
mock.stop()
457457

458+
def test_get_batches_success(self):
459+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches?api_key={ROBOFLOW_API_KEY}"
460+
mock_response = {
461+
"batches": [
462+
{
463+
"name": "Uploaded on 11/22/22 at 1:39 pm",
464+
"numJobs": 2,
465+
"images": 115,
466+
"uploaded": {"_seconds": 1669146024, "_nanoseconds": 818000000},
467+
"id": "batch-1",
468+
},
469+
{
470+
"numJobs": 0,
471+
"images": 11,
472+
"uploaded": {"_seconds": 1669236873, "_nanoseconds": 47000000},
473+
"name": "Upload via API",
474+
"id": "batch-2",
475+
},
476+
]
477+
}
478+
479+
responses.add(responses.GET, expected_url, json=mock_response, status=200)
480+
481+
batches = self.project.get_batches()
482+
483+
self.assertIsInstance(batches, dict)
484+
self.assertIn("batches", batches)
485+
self.assertEqual(len(batches["batches"]), 2)
486+
self.assertEqual(batches["batches"][0]["id"], "batch-1")
487+
self.assertEqual(batches["batches"][0]["name"], "Uploaded on 11/22/22 at 1:39 pm")
488+
self.assertEqual(batches["batches"][0]["images"], 115)
489+
self.assertEqual(batches["batches"][0]["numJobs"], 2)
490+
self.assertEqual(batches["batches"][1]["id"], "batch-2")
491+
self.assertEqual(batches["batches"][1]["name"], "Upload via API")
492+
493+
def test_get_batches_error(self):
494+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches?api_key={ROBOFLOW_API_KEY}"
495+
error_response = {"error": "Cannot retrieve batches"}
496+
497+
responses.add(responses.GET, expected_url, json=error_response, status=404)
498+
499+
with self.assertRaises(RuntimeError) as context:
500+
self.project.get_batches()
501+
502+
self.assertEqual(str(context.exception), "Cannot retrieve batches")
503+
504+
def test_get_batch_success(self):
505+
batch_id = "batch-123"
506+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches/{batch_id}?api_key={ROBOFLOW_API_KEY}"
507+
mock_response = {
508+
"batch": {
509+
"name": "Uploaded on 11/22/22 at 1:39 pm",
510+
"numJobs": 2,
511+
"images": 115,
512+
"uploaded": {"_seconds": 1669146024, "_nanoseconds": 818000000},
513+
"id": batch_id,
514+
}
515+
}
516+
517+
responses.add(responses.GET, expected_url, json=mock_response, status=200)
518+
519+
batch = self.project.get_batch(batch_id)
520+
521+
self.assertIsInstance(batch, dict)
522+
self.assertIn("batch", batch)
523+
self.assertEqual(batch["batch"]["id"], batch_id)
524+
self.assertEqual(batch["batch"]["name"], "Uploaded on 11/22/22 at 1:39 pm")
525+
self.assertEqual(batch["batch"]["images"], 115)
526+
self.assertEqual(batch["batch"]["numJobs"], 2)
527+
self.assertIn("uploaded", batch["batch"])
528+
529+
def test_get_batch_error(self):
530+
batch_id = "nonexistent-batch"
531+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches/{batch_id}?api_key={ROBOFLOW_API_KEY}"
532+
error_response = {"error": "Batch not found"}
533+
534+
responses.add(responses.GET, expected_url, json=error_response, status=404)
535+
536+
with self.assertRaises(RuntimeError) as context:
537+
self.project.get_batch(batch_id)
538+
539+
self.assertEqual(str(context.exception), "Batch not found")
540+
458541
def test_classification_dataset_upload(self):
459542
from roboflow.util import folderparser
460543

0 commit comments

Comments
 (0)