Skip to content

Commit 6a6025a

Browse files
authored
Add support to search by annotation_job and annotation_job_id (#406)
1 parent 92bd1b9 commit 6a6025a

File tree

3 files changed

+124
-2
lines changed

3 files changed

+124
-2
lines changed

roboflow/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from roboflow.models import CLIPModel, GazeModel # noqa: F401
1616
from roboflow.util.general import write_line
1717

18-
__version__ = "1.2.4"
18+
__version__ = "1.2.5"
1919

2020

2121
def check_key(api_key, model, notebook, num_retries=0):

roboflow/core/project.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,9 @@ def search(
653653
batch: bool = False,
654654
batch_id: Optional[str] = None,
655655
fields: Optional[List[str]] = None,
656+
*,
657+
annotation_job: Optional[bool] = None,
658+
annotation_job_id: Optional[str] = None,
656659
):
657660
"""
658661
Search for images in a project.
@@ -667,6 +670,8 @@ def search(
667670
in_dataset (str): dataset that an image must be in
668671
batch (bool): whether the image must be in a batch
669672
batch_id (str): batch id that an image must be in
673+
annotation_job (bool): whether the image must be in an annotation job
674+
annotation_job_id (str): annotation job id that an image must be in
670675
fields (list): fields to return in results (default: ["id", "created", "name", "labels"])
671676
672677
Returns:
@@ -684,7 +689,7 @@ def search(
684689
if fields is None:
685690
fields = ["id", "created", "name", "labels"]
686691

687-
payload: Dict[str, Union[str, int, List[str]]] = {}
692+
payload: Dict[str, Union[str, int, bool, List[str]]] = {}
688693

689694
if like_image is not None:
690695
payload["like_image"] = like_image
@@ -713,6 +718,12 @@ def search(
713718
if batch_id is not None:
714719
payload["batch_id"] = batch_id
715720

721+
if annotation_job is not None:
722+
payload["annotation_job"] = annotation_job
723+
724+
if annotation_job_id is not None:
725+
payload["annotation_job_id"] = annotation_job_id
726+
716727
payload["fields"] = fields
717728

718729
data = requests.post(
@@ -734,6 +745,9 @@ def search_all(
734745
batch: bool = False,
735746
batch_id: Optional[str] = None,
736747
fields: Optional[List[str]] = None,
748+
*,
749+
annotation_job: Optional[bool] = None,
750+
annotation_job_id: Optional[str] = None,
737751
):
738752
"""
739753
Create a paginated list of search results for use in searching the images in a project.
@@ -748,6 +762,8 @@ def search_all(
748762
in_dataset (str): dataset that an image must be in
749763
batch (bool): whether the image must be in a batch
750764
batch_id (str): batch id that an image must be in
765+
annotation_job (bool): whether the image must be in an annotation job
766+
annotation_job_id (str): annotation job id that an image must be in
751767
fields (list): fields to return in results (default: ["id", "created", "name", "labels"])
752768
753769
Returns:
@@ -781,6 +797,8 @@ def search_all(
781797
batch=batch,
782798
batch_id=batch_id,
783799
fields=fields,
800+
annotation_job=annotation_job,
801+
annotation_job_id=annotation_job_id,
784802
)
785803

786804
yield data

tests/test_project.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,3 +667,107 @@ def capture_annotation_calls(annotation_path, **kwargs):
667667
finally:
668668
for mock in mocks.values():
669669
mock.stop()
670+
671+
def test_search_with_annotation_job_params(self):
672+
"""Test that annotation_job and annotation_job_id parameters are properly included in search requests"""
673+
# Test 1: Search with annotation_job=True
674+
expected_url = f"{API_URL}/{WORKSPACE_NAME}/{PROJECT_NAME}/search?api_key={ROBOFLOW_API_KEY}"
675+
mock_response = {
676+
"results": [
677+
{"id": "image1", "name": "test1.jpg", "created": 1616161616, "labels": ["person"]},
678+
{"id": "image2", "name": "test2.jpg", "created": 1616161617, "labels": ["car"]},
679+
]
680+
}
681+
682+
responses.add(
683+
responses.POST,
684+
expected_url,
685+
json=mock_response,
686+
status=200,
687+
match=[
688+
json_params_matcher(
689+
{
690+
"offset": 0,
691+
"limit": 100,
692+
"batch": False,
693+
"annotation_job": True,
694+
"fields": ["id", "created", "name", "labels"],
695+
}
696+
)
697+
],
698+
)
699+
700+
results = self.project.search(annotation_job=True)
701+
self.assertEqual(len(results), 2)
702+
self.assertEqual(results[0]["id"], "image1")
703+
704+
# Test 2: Search with annotation_job_id
705+
test_job_id = "job_123456"
706+
responses.add(
707+
responses.POST,
708+
expected_url,
709+
json=mock_response,
710+
status=200,
711+
match=[
712+
json_params_matcher(
713+
{
714+
"offset": 0,
715+
"limit": 100,
716+
"batch": False,
717+
"annotation_job_id": test_job_id,
718+
"fields": ["id", "created", "name", "labels"],
719+
}
720+
)
721+
],
722+
)
723+
724+
results = self.project.search(annotation_job_id=test_job_id)
725+
self.assertEqual(len(results), 2)
726+
727+
# Test 3: Search with both parameters
728+
responses.add(
729+
responses.POST,
730+
expected_url,
731+
json=mock_response,
732+
status=200,
733+
match=[
734+
json_params_matcher(
735+
{
736+
"offset": 0,
737+
"limit": 50,
738+
"batch": False,
739+
"annotation_job": False,
740+
"annotation_job_id": test_job_id,
741+
"prompt": "dog",
742+
"fields": ["id", "created", "name", "labels"],
743+
}
744+
)
745+
],
746+
)
747+
748+
results = self.project.search(prompt="dog", annotation_job=False, annotation_job_id=test_job_id, limit=50)
749+
self.assertEqual(len(results), 2)
750+
751+
# Test 4: Verify parameters are not included when None
752+
responses.add(
753+
responses.POST,
754+
expected_url,
755+
json=mock_response,
756+
status=200,
757+
match=[
758+
json_params_matcher(
759+
{
760+
"offset": 0,
761+
"limit": 100,
762+
"batch": False,
763+
"fields": ["id", "created", "name", "labels"],
764+
# annotation_job and annotation_job_id should NOT be in the payload
765+
}
766+
)
767+
],
768+
)
769+
770+
# This should pass because json_params_matcher only checks that the
771+
# specified keys match, it doesn't fail if additional keys are missing
772+
results = self.project.search()
773+
self.assertEqual(len(results), 2)

0 commit comments

Comments
 (0)