Skip to content

Commit 0459e0a

Browse files
Merge pull request #24 from tjmlabs/filter
Add filter method to the sdk
2 parents 0db6d65 + 173d2a7 commit 0459e0a

File tree

2 files changed

+255
-37
lines changed

2 files changed

+255
-37
lines changed

colivara_py/client.py

Lines changed: 79 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,16 @@
1-
import os
2-
import requests
3-
from typing import Optional, Dict, Any, List, Union
4-
from .models import (
5-
CollectionIn,
6-
CollectionOut,
7-
GenericError,
8-
GenericMessage,
9-
PatchCollectionIn,
10-
DocumentIn,
11-
DocumentOut,
12-
DocumentInPatch,
13-
QueryIn,
14-
QueryOut,
15-
QueryFilter,
16-
FileOut,
17-
EmbeddingsOut,
18-
TaskEnum,
19-
EmbeddingsIn,
20-
)
211
import base64
2+
import os
223
from pathlib import Path
4+
from typing import Any, Dict, List, Optional, Union
5+
6+
import requests
237
from pydantic import ValidationError
248

9+
from .models import (CollectionIn, CollectionOut, DocumentIn, DocumentInPatch,
10+
DocumentOut, EmbeddingsIn, EmbeddingsOut, FileOut,
11+
GenericError, GenericMessage, PatchCollectionIn,
12+
QueryFilter, QueryIn, QueryOut, TaskEnum)
13+
2514

2615
class ColiVara:
2716
def __init__(self, base_url: Optional[str] = None, api_key: Optional[str] = None):
@@ -479,6 +468,76 @@ def search(
479468
else:
480469
response.raise_for_status()
481470

471+
def filter(
472+
self,
473+
query_filter: Dict[str, Any],
474+
expand: Optional[str] = None,
475+
) -> List[Union[DocumentOut, CollectionOut]]:
476+
"""
477+
Filter for documents and collections that meet the criteria of the filter.
478+
479+
Args:
480+
query_filter (Dict[str, Any]): A dictionary specifying the filter criteria.
481+
The filter can be used to narrow down the search based on specific criteria.
482+
The dictionary should contain the following keys:
483+
- "on": "document" or "collection"
484+
- "key": str or List[str]
485+
- "value": Optional[Union[str, int, float, bool]]
486+
- "lookup": One of "key_lookup", "contains", "contained_by", "has_key", "has_keys", "has_any_keys"
487+
expand (Optional[str]): A comma-separated list of fields to expand in the response.
488+
Currently, only "pages" is supported, the document's pages will be included if provided.
489+
490+
491+
Returns:
492+
DocumentOut: The retrieved documents with their details.
493+
CollectionOut: The retrieved collections with their details.
494+
495+
Raises:
496+
ValueError: If the query_filter is invalid.
497+
requests.HTTPError: If the API request fails.
498+
499+
Example:
500+
# Simple filter
501+
results = client.filter({
502+
"on": "document",
503+
"key": "category",
504+
"value": "AI",
505+
"lookup": "contains"
506+
})
507+
508+
# Filter with a list of keys
509+
results = client.filter({
510+
"on": "collection",
511+
"key": ["tag1", "tag2"],
512+
"lookup": "has_keys"
513+
})
514+
"""
515+
516+
request_url = f"{self.base_url}/v1/filter/"
517+
518+
try:
519+
filter_obj = QueryFilter(**query_filter)
520+
payload = filter_obj.model_dump()
521+
except ValidationError as e:
522+
raise ValueError(f"Invalid query_filter: {str(e)}")
523+
524+
params = {"expand": expand}
525+
526+
response = requests.post(
527+
request_url, json=payload, params=params, headers=self.headers
528+
)
529+
530+
if response.status_code == 200:
531+
if query_filter["on"] == "document":
532+
return [DocumentOut(**doc) for doc in response.json()]
533+
else:
534+
return [CollectionOut(**col) for col in response.json()]
535+
elif response.status_code == 503:
536+
error = GenericError(**response.json())
537+
raise ValueError(f"Service unavailable: {error.detail}")
538+
else:
539+
response.raise_for_status()
540+
482541
def file_to_imgbase64(self, file_path: str) -> List[FileOut]:
483542
"""
484543
Converts a file to a list of base64 encoded images.

tests/test_client.py

Lines changed: 176 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,17 @@
1+
import base64
12
import os
3+
from pathlib import Path
4+
25
import pytest
3-
import base64
4-
from colivara_py import ColiVara, AsyncColiVara
5-
from colivara_py.models import (
6-
CollectionOut,
7-
DocumentOut,
8-
QueryOut,
9-
PageOutQuery,
10-
FileOut,
11-
EmbeddingsOut,
12-
PatchCollectionIn,
13-
DocumentIn,
14-
DocumentInPatch,
15-
QueryFilter,
16-
GenericMessage,
17-
)
186
import responses
19-
from requests.exceptions import HTTPError
207
from pydantic import ValidationError
21-
from pathlib import Path
8+
from requests.exceptions import HTTPError
9+
10+
from colivara_py import AsyncColiVara, ColiVara
11+
from colivara_py.models import (CollectionOut, DocumentIn, DocumentInPatch,
12+
DocumentOut, EmbeddingsOut, FileOut,
13+
GenericMessage, PageOutQuery,
14+
PatchCollectionIn, QueryFilter, QueryOut)
2215

2316

2417
def test_colivara_init_no_api_key():
@@ -988,6 +981,113 @@ def test_search_with_filter(api_key):
988981
assert result.results[0].document_metadata["category"] == "AI"
989982

990983

984+
@responses.activate
985+
def test_filter_documents(api_key):
986+
os.environ["COLIVARA_API_KEY"] = api_key
987+
base_url = "https://api.test.com"
988+
client = ColiVara(base_url=base_url)
989+
990+
expected_out = [
991+
{
992+
"id": 1,
993+
"name": "Test Document Fixture",
994+
"metadata": {"important": True},
995+
"url": "https://www.example.com",
996+
"num_pages": 1,
997+
"collection_name": "Test Collection Fixture",
998+
}
999+
]
1000+
1001+
responses.add(
1002+
responses.POST, f"{client.base_url}/v1/filter/", json=expected_out, status=200
1003+
)
1004+
1005+
query_filter = {
1006+
"on": "document",
1007+
"key": "important",
1008+
"value": True,
1009+
}
1010+
result = client.filter(query_filter=query_filter)
1011+
assert isinstance(result, list)
1012+
assert len(result) == 1
1013+
1014+
1015+
@responses.activate
1016+
def test_filter_documents_expand(api_key):
1017+
os.environ["COLIVARA_API_KEY"] = api_key
1018+
base_url = "https://api.test.com"
1019+
client = ColiVara(base_url=base_url)
1020+
1021+
expected_out = [
1022+
{
1023+
"id": 1,
1024+
"name": "Test Document Fixture",
1025+
"metadata": {"important": True},
1026+
"url": "https://www.example.com",
1027+
"num_pages": 1,
1028+
"collection_name": "Test Collection Fixture",
1029+
"pages": [
1030+
{
1031+
"document_name": "Test Document Fixture",
1032+
"img_base64": "base64_string",
1033+
"page_number": 1,
1034+
}
1035+
],
1036+
}
1037+
]
1038+
1039+
responses.add(
1040+
responses.POST,
1041+
f"{client.base_url}/v1/filter/?expand=pages",
1042+
json=expected_out,
1043+
status=200,
1044+
)
1045+
1046+
query_filter = {
1047+
"on": "document",
1048+
"key": "important",
1049+
"value": True,
1050+
}
1051+
result = client.filter(query_filter=query_filter, expand="pages")
1052+
assert isinstance(result, list)
1053+
assert len(result) == 1
1054+
1055+
1056+
@responses.activate
1057+
def test_filter_collections(api_key):
1058+
os.environ["COLIVARA_API_KEY"] = api_key
1059+
base_url = "https://api.test.com"
1060+
client = ColiVara(base_url=base_url)
1061+
1062+
expected_out = [
1063+
{
1064+
"id": 1,
1065+
"name": "test_collection",
1066+
"metadata": {"description": "A test collection"},
1067+
"num_documents": 2,
1068+
},
1069+
{
1070+
"id": 2,
1071+
"name": "another_test_collection",
1072+
"metadata": {"description": "Another test collection"},
1073+
"num_documents": 3,
1074+
},
1075+
]
1076+
1077+
responses.add(
1078+
responses.POST, f"{client.base_url}/v1/filter/", json=expected_out, status=200
1079+
)
1080+
1081+
query_filter = {
1082+
"on": "collection",
1083+
"key": "important",
1084+
"value": True,
1085+
}
1086+
result = client.filter(query_filter=query_filter, expand="pages")
1087+
assert isinstance(result, list)
1088+
assert len(result) == 2
1089+
1090+
9911091
@responses.activate
9921092
def test_search_service_unavailable(api_key):
9931093
os.environ["COLIVARA_API_KEY"] = api_key
@@ -1006,6 +1106,31 @@ def test_search_service_unavailable(api_key):
10061106
assert "Service unavailable" in str(exc_info.value)
10071107

10081108

1109+
@responses.activate
1110+
def test_filter_service_unavailable(api_key):
1111+
os.environ["COLIVARA_API_KEY"] = api_key
1112+
base_url = "https://api.test.com"
1113+
client = ColiVara(base_url=base_url)
1114+
1115+
error_response = {"detail": "Service is temporarily unavailable"}
1116+
1117+
responses.add(
1118+
responses.POST, f"{client.base_url}/v1/filter/", json=error_response, status=503
1119+
)
1120+
1121+
with pytest.raises(ValueError) as exc_info:
1122+
client.filter(
1123+
query_filter={
1124+
"on": "document",
1125+
"key": "category",
1126+
"value": "AI",
1127+
"lookup": "contains",
1128+
}
1129+
)
1130+
1131+
assert "Service unavailable" in str(exc_info.value)
1132+
1133+
10091134
@responses.activate
10101135
def test_search_invalid_filter(api_key):
10111136
os.environ["COLIVARA_API_KEY"] = api_key
@@ -1018,6 +1143,18 @@ def test_search_invalid_filter(api_key):
10181143
assert "Invalid query_filter" in str(exc_info.value)
10191144

10201145

1146+
@responses.activate
1147+
def test_filter_invalid_filter(api_key):
1148+
os.environ["COLIVARA_API_KEY"] = api_key
1149+
base_url = "https://api.test.com"
1150+
client = ColiVara(base_url=base_url)
1151+
1152+
with pytest.raises(ValueError) as exc_info:
1153+
client.filter(query_filter={"invalid": "filter"})
1154+
1155+
assert "Invalid query_filter" in str(exc_info.value)
1156+
1157+
10211158
@responses.activate
10221159
def test_search_http_error(api_key):
10231160
os.environ["COLIVARA_API_KEY"] = api_key
@@ -1033,6 +1170,28 @@ def test_search_http_error(api_key):
10331170
client.search("what is 1+1?")
10341171

10351172

1173+
@responses.activate
1174+
def test_filter_http_error(api_key):
1175+
os.environ["COLIVARA_API_KEY"] = api_key
1176+
base_url = "https://api.test.com"
1177+
client = ColiVara(base_url=base_url)
1178+
responses.add(
1179+
responses.POST,
1180+
f"{client.base_url}/v1/filter/",
1181+
json={"error": "Internal Server Error"},
1182+
status=500,
1183+
)
1184+
with pytest.raises(HTTPError):
1185+
client.filter(
1186+
query_filter={
1187+
"on": "document",
1188+
"key": "category",
1189+
"value": "AI",
1190+
"lookup": "contains",
1191+
}
1192+
)
1193+
1194+
10361195
@pytest.fixture
10371196
def test_file_path(tmp_path):
10381197
file_content = b"Test file content"

0 commit comments

Comments
 (0)