Skip to content

Commit d0bbaa4

Browse files
committed
fix it
1 parent f40a653 commit d0bbaa4

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

roboflow/core/project.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,3 +869,68 @@ def create_annotation_job(
869869
raise RuntimeError(f"Failed to create annotation job: {response.text}")
870870

871871
return response.json()
872+
873+
def get_batches(self) -> Dict:
874+
"""
875+
Get a list of all batches in the project.
876+
877+
Returns:
878+
Dict: A dictionary containing the list of batches
879+
880+
Example:
881+
>>> import roboflow
882+
883+
>>> rf = roboflow.Roboflow(api_key="YOUR_API_KEY")
884+
885+
>>> project = rf.workspace().project("PROJECT_ID")
886+
887+
>>> batches = project.get_batches()
888+
"""
889+
url = f"{API_URL}/{self.__workspace}/{self.__project_name}/batches?api_key={self.__api_key}"
890+
891+
response = requests.get(url)
892+
893+
if response.status_code != 200:
894+
try:
895+
error_data = response.json()
896+
if "error" in error_data:
897+
raise RuntimeError(error_data["error"])
898+
raise RuntimeError(response.text)
899+
except ValueError:
900+
raise RuntimeError(f"Failed to get batches: {response.text}")
901+
902+
return response.json()
903+
904+
def get_batch(self, batch_id: str) -> Dict:
905+
"""
906+
Get information for a specific batch in the project.
907+
908+
Args:
909+
batch_id (str): The ID of the batch to retrieve
910+
911+
Returns:
912+
Dict: A dictionary containing the batch details
913+
914+
Example:
915+
>>> import roboflow
916+
917+
>>> rf = roboflow.Roboflow(api_key="YOUR_API_KEY")
918+
919+
>>> project = rf.workspace().project("PROJECT_ID")
920+
921+
>>> batch = project.get_batch("batch123")
922+
"""
923+
url = f"{API_URL}/{self.__workspace}/{self.__project_name}/batches/{batch_id}?api_key={self.__api_key}"
924+
925+
response = requests.get(url)
926+
927+
if response.status_code != 200:
928+
try:
929+
error_data = response.json()
930+
if "error" in error_data:
931+
raise RuntimeError(error_data["error"])
932+
raise RuntimeError(response.text)
933+
except ValueError:
934+
raise RuntimeError(f"Failed to get batch {batch_id}: {response.text}")
935+
936+
return response.json()

tests/test_project.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,3 +454,82 @@ def test_project_upload_dataset(self):
454454
finally:
455455
for mock in mocks.values():
456456
mock.stop()
457+
458+
def test_get_batches_success(self):
459+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches?api_key={ROBOFLOW_API_KEY}"
460+
mock_response = {
461+
"batches": [
462+
{
463+
"id": "batch-1",
464+
"name": "Batch 1",
465+
"created": 1616161616,
466+
"images": 10,
467+
},
468+
{
469+
"id": "batch-2",
470+
"name": "Batch 2",
471+
"created": 1616161617,
472+
"images": 5,
473+
}
474+
]
475+
}
476+
477+
responses.add(responses.GET, expected_url, json=mock_response, status=200)
478+
479+
batches = self.project.get_batches()
480+
481+
self.assertIsInstance(batches, dict)
482+
self.assertIn("batches", batches)
483+
self.assertEqual(len(batches["batches"]), 2)
484+
self.assertEqual(batches["batches"][0]["id"], "batch-1")
485+
self.assertEqual(batches["batches"][1]["id"], "batch-2")
486+
487+
def test_get_batches_error(self):
488+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches?api_key={ROBOFLOW_API_KEY}"
489+
error_response = {"error": "Cannot retrieve batches"}
490+
491+
responses.add(responses.GET, expected_url, json=error_response, status=404)
492+
493+
with self.assertRaises(RuntimeError) as context:
494+
self.project.get_batches()
495+
496+
self.assertEqual(str(context.exception), "Cannot retrieve batches")
497+
498+
def test_get_batch_success(self):
499+
batch_id = "batch-123"
500+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches/{batch_id}?api_key={ROBOFLOW_API_KEY}"
501+
mock_response = {
502+
"batch": {
503+
"id": batch_id,
504+
"name": "My Test Batch",
505+
"created": 1616161616,
506+
"images": 25,
507+
"metadata": {
508+
"source": "API Upload",
509+
"type": "test"
510+
}
511+
}
512+
}
513+
514+
responses.add(responses.GET, expected_url, json=mock_response, status=200)
515+
516+
batch = self.project.get_batch(batch_id)
517+
518+
self.assertIsInstance(batch, dict)
519+
self.assertIn("batch", batch)
520+
self.assertEqual(batch["batch"]["id"], batch_id)
521+
self.assertEqual(batch["batch"]["name"], "My Test Batch")
522+
self.assertEqual(batch["batch"]["images"], 25)
523+
self.assertIn("metadata", batch["batch"])
524+
525+
def test_get_batch_error(self):
526+
batch_id = "nonexistent-batch"
527+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/batches/{batch_id}?api_key={ROBOFLOW_API_KEY}"
528+
error_response = {"error": "Batch not found"}
529+
530+
responses.add(responses.GET, expected_url, json=error_response, status=404)
531+
532+
with self.assertRaises(RuntimeError) as context:
533+
self.project.get_batch(batch_id)
534+
535+
self.assertEqual(str(context.exception), "Batch not found")

0 commit comments

Comments
 (0)