Skip to content

Commit 8b6e3cb

Browse files
committed
fix(storage): fix more stuff
1 parent bf9384f commit 8b6e3cb

File tree

6 files changed

+138
-47
lines changed

6 files changed

+138
-47
lines changed

src/storage/src/storage3/_async/analytics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@
99
SortColumn,
1010
SortOrder,
1111
)
12-
from .request import RequestBuilder
12+
from .request import AsyncRequestBuilder
1313

1414

1515
class AsyncStorageAnalyticsClient:
16-
def __init__(self, request: RequestBuilder) -> None:
16+
def __init__(self, request: AsyncRequestBuilder) -> None:
1717
self._request = request
1818

1919
async def create(self, bucket_name: str) -> AnalyticsBucket:
Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
from typing import Optional
22

3-
from httpx import AsyncClient, Headers, QueryParams, Response
3+
from httpx import AsyncClient, Headers, HTTPStatusError, QueryParams, Response
4+
from pydantic import ValidationError
45
from yarl import URL
56

7+
from ..exceptions import StorageApiError, VectorBucketErrorMessage
68
from ..types import JSON, RequestMethod
79

810

9-
class RequestBuilder:
11+
class AsyncRequestBuilder:
1012
def __init__(self, session: AsyncClient, base_url: URL, headers: Headers) -> None:
1113
self._session = session
1214
self._base_url = base_url
@@ -19,12 +21,27 @@ async def send(
1921
body: JSON = None,
2022
query_params: Optional[QueryParams] = None,
2123
) -> Response:
22-
data = await self._session.request(
24+
response = await self._session.request(
2325
method=http_method,
2426
json=body,
2527
url=str(self._base_url.joinpath(*path)),
2628
headers=self.headers,
2729
params=query_params or QueryParams(),
2830
)
29-
data.raise_for_status()
30-
return data
31+
try:
32+
response.raise_for_status()
33+
return response
34+
except HTTPStatusError as exc:
35+
try:
36+
error = VectorBucketErrorMessage.model_validate_json(response.content)
37+
raise StorageApiError(
38+
message=error.message,
39+
code=error.code or "400",
40+
status=error.statusCode,
41+
) from exc
42+
except ValidationError as exc:
43+
raise StorageApiError(
44+
message="The request failed, but could not parse error message response.",
45+
code="LibraryError",
46+
status=response.status_code,
47+
) from exc

src/storage/src/storage3/_async/vectors.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,41 @@
55
from httpx import AsyncClient, Headers
66
from yarl import URL
77

8-
from ..exceptions import VectorBucketException
8+
from ..exceptions import StorageApiError, VectorBucketException
99
from ..types import (
1010
JSON,
1111
DistanceMetric,
12+
GetVectorBucketResponse,
13+
GetVectorIndexResponse,
1214
GetVectorsResponse,
13-
ListIndexesResponse,
15+
ListVectorBucketsResponse,
16+
ListVectorIndexesResponse,
1417
ListVectorsResponse,
1518
MetadataConfiguration,
1619
QueryVectorsResponse,
20+
VectorBucket,
1721
VectorData,
1822
VectorFilter,
1923
VectorIndex,
24+
VectorMatch,
2025
VectorObject,
2126
)
22-
from .request import RequestBuilder
27+
from .request import AsyncRequestBuilder
28+
29+
30+
# used to not send non-required values as `null`
31+
# for they cannot be null
32+
def remove_none(**kwargs: JSON) -> JSON:
33+
return {key: val for key, val in kwargs.items() if val is not None}
2334

2435

2536
class AsyncVectorBucketScope:
26-
def __init__(self, request: RequestBuilder, bucket_name: str) -> None:
37+
def __init__(self, request: AsyncRequestBuilder, bucket_name: str) -> None:
2738
self._request = request
2839
self._bucket_name = bucket_name
2940

3041
def with_metadata(self, **data: JSON) -> JSON:
31-
return {"vectorBucketName": self._bucket_name, **data}
42+
return remove_none(vectorBucketName=self._bucket_name, **data)
3243

3344
async def create_index(
3445
self,
@@ -47,26 +58,29 @@ async def create_index(
4758
)
4859
await self._request.send(http_method="POST", path=["CreateIndex"], body=body)
4960

50-
async def get_index(self, index_name: str) -> VectorIndex:
61+
async def get_index(self, index_name: str) -> Optional[VectorIndex]:
5162
body = self.with_metadata(indexName=index_name)
52-
data = await self._request.send(
53-
http_method="POST", path=["GetIndex"], body=body
54-
)
55-
return VectorIndex.model_validate(data.content)
63+
try:
64+
data = await self._request.send(
65+
http_method="POST", path=["GetIndex"], body=body
66+
)
67+
return GetVectorIndexResponse.model_validate_json(data.content).index
68+
except StorageApiError:
69+
return None
5670

5771
async def list_indexes(
5872
self,
5973
next_token: Optional[str] = None,
6074
max_results: Optional[int] = None,
6175
prefix: Optional[str] = None,
62-
) -> ListIndexesResponse:
76+
) -> ListVectorIndexesResponse:
6377
body = self.with_metadata(
6478
next_token=next_token, max_results=max_results, prefix=prefix
6579
)
6680
data = await self._request.send(
6781
http_method="POST", path=["ListIndexes"], body=body
6882
)
69-
return ListIndexesResponse.model_validate(data.content)
83+
return ListVectorIndexesResponse.model_validate_json(data.content)
7084

7185
async def delete_index(self, index_name: str) -> None:
7286
body = self.with_metadata(indexName=index_name)
@@ -78,33 +92,33 @@ def index(self, index_name: str) -> AsyncVectorIndexScope:
7892

7993
class AsyncVectorIndexScope:
8094
def __init__(
81-
self, request: RequestBuilder, bucket_name: str, index_name: str
95+
self, request: AsyncRequestBuilder, bucket_name: str, index_name: str
8296
) -> None:
8397
self._request = request
8498
self._bucket_name = bucket_name
8599
self._index_name = index_name
86100

87101
def with_metadata(self, **data: JSON) -> JSON:
88-
return {
89-
"vectorBucketName": self._bucket_name,
90-
"indexName": self._index_name,
102+
return remove_none(
103+
vectorBucketName=self._bucket_name,
104+
indexName=self._index_name,
91105
**data,
92-
}
106+
)
93107

94108
async def put(self, vectors: List[VectorObject]) -> None:
95109
body = self.with_metadata(vectors=[v.as_json() for v in vectors])
96110
await self._request.send(http_method="POST", path=["PutVectors"], body=body)
97111

98112
async def get(
99113
self, *keys: str, return_data: bool = True, return_metadata: bool = True
100-
) -> GetVectorsResponse:
114+
) -> List[VectorMatch]:
101115
body = self.with_metadata(
102116
keys=keys, returnData=return_data, returnMetadata=return_metadata
103117
)
104118
data = await self._request.send(
105119
http_method="POST", path=["GetVectors"], body=body
106120
)
107-
return GetVectorsResponse.model_validate(data.content)
121+
return GetVectorsResponse.model_validate_json(data.content).vectors
108122

109123
async def list(
110124
self,
@@ -126,7 +140,7 @@ async def list(
126140
data = await self._request.send(
127141
http_method="POST", path=["ListVectors"], body=body
128142
)
129-
return ListVectorsResponse.model_validate(data.content)
143+
return ListVectorsResponse.model_validate_json(data.content)
130144

131145
async def query(
132146
self,
@@ -146,7 +160,7 @@ async def query(
146160
data = await self._request.send(
147161
http_method="POST", path=["QueryVectors"], body=body
148162
)
149-
return QueryVectorsResponse.model_validate(data.content)
163+
return QueryVectorsResponse.model_validate_json(data.content)
150164

151165
async def delete(self, keys: List[str]) -> None:
152166
if 1 < len(keys) or len(keys) > 500:
@@ -157,7 +171,7 @@ async def delete(self, keys: List[str]) -> None:
157171

158172
class AsyncStorageVectorsClient:
159173
def __init__(self, url: URL, headers: Headers, session: AsyncClient) -> None:
160-
self._request = RequestBuilder(session, base_url=URL(url), headers=headers)
174+
self._request = AsyncRequestBuilder(session, base_url=URL(url), headers=headers)
161175

162176
def from_(self, bucket_name: str) -> AsyncVectorBucketScope:
163177
return AsyncVectorBucketScope(self._request, bucket_name)
@@ -168,7 +182,32 @@ async def create_bucket(self, bucket_name: str) -> None:
168182
http_method="POST", path=["CreateVectorBucket"], body=body
169183
)
170184

171-
# async def get_bucket(self, bucket_name: str) -> GetBucketResponse:
172-
# body = { 'vectorBucketName': bucket_name }
173-
# data = await self._request.send(http_method='POST', path=['GetVectorBucket'], body=body)
174-
# return GetVectorsResponse.model_validate(data.content)
185+
async def get_bucket(self, bucket_name: str) -> Optional[VectorBucket]:
186+
body = {"vectorBucketName": bucket_name}
187+
try:
188+
data = await self._request.send(
189+
http_method="POST", path=["GetVectorBucket"], body=body
190+
)
191+
return GetVectorBucketResponse.model_validate_json(
192+
data.content
193+
).vectorBucket
194+
except StorageApiError:
195+
return None
196+
197+
async def list_buckets(
198+
self,
199+
prefix: Optional[str] = None,
200+
max_results: Optional[int] = None,
201+
next_token: Optional[str] = None,
202+
) -> ListVectorBucketsResponse:
203+
body = remove_none(prefix=prefix, maxResults=max_results, nextToken=next_token)
204+
data = await self._request.send(
205+
http_method="POST", path=["ListVectorBuckets"], body=body
206+
)
207+
return ListVectorBucketsResponse.model_validate_json(data.content)
208+
209+
async def delete_bucket(self, bucket_name: str) -> None:
210+
body = {"vectorBucketName": bucket_name}
211+
await self._request.send(
212+
http_method="POST", path=["DeleteVectorBucket"], body=body
213+
)

src/storage/src/storage3/_sync/vectors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
JSON,
1111
DistanceMetric,
1212
GetVectorsResponse,
13-
ListIndexesResponse,
13+
ListVectorIndexesResponse,
1414
ListVectorsResponse,
1515
MetadataConfiguration,
1616
QueryVectorsResponse,
@@ -55,12 +55,12 @@ def list_indexes(
5555
next_token: Optional[str] = None,
5656
max_results: Optional[int] = None,
5757
prefix: Optional[str] = None,
58-
) -> ListIndexesResponse:
58+
) -> ListVectorIndexesResponse:
5959
body = self.with_metadata(
6060
next_token=next_token, max_results=max_results, prefix=prefix
6161
)
6262
data = self._request.send(http_method="POST", path=["ListIndexes"], body=body)
63-
return ListIndexesResponse.model_validate(data.content)
63+
return ListVectorIndexesResponse.model_validate(data.content)
6464

6565
def delete_index(self, index_name: str) -> None:
6666
body = self.with_metadata(indexName=index_name)

src/storage/src/storage3/exceptions.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import TypedDict
1+
from typing import Optional, TypedDict, Union
2+
3+
from pydantic import BaseModel
24

35
from .utils import StorageException
46

@@ -8,17 +10,24 @@ def __init__(self, msg: str) -> None:
810
self.msg = msg
911

1012

13+
class VectorBucketErrorMessage(BaseModel):
14+
statusCode: Union[str, int]
15+
error: str
16+
message: str
17+
code: Optional[str] = None
18+
19+
1120
class StorageApiErrorDict(TypedDict):
1221
name: str
1322
message: str
1423
code: str
15-
status: int
24+
status: Union[int, str]
1625

1726

1827
class StorageApiError(StorageException):
1928
"""Error raised when an operation on the storage API fails."""
2029

21-
def __init__(self, message: str, code: str, status: int) -> None:
30+
def __init__(self, message: str, code: str, status: Union[int, str]) -> None:
2231
error_message = (
2332
f"{{'statusCode': {status}, 'error': {code}, 'message': {message}}}"
2433
)

src/storage/src/storage3/types.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,11 @@ class CreateSignedUploadUrlOptions(BaseModel):
159159
total=False,
160160
)
161161

162-
DistanceMetric: TypeAlias = Literal["cosine", "euclidean", "dotproduct"]
162+
DistanceMetric: TypeAlias = Literal["cosine", "euclidean"]
163163

164164

165165
class MetadataConfiguration(BaseModel):
166-
non_filterable_metadata_keys: Optional[List[str]] = Field(
167-
alias="nonFilterableMetadaKeys", default=None
168-
)
166+
nonFilterableMetadataKeys: Optional[List[str]]
169167

170168

171169
class ListIndexesOptions(BaseModel):
@@ -178,9 +176,9 @@ class ListIndexesResponseItem(BaseModel):
178176
indexName: str
179177

180178

181-
class ListIndexesResponse(BaseModel):
179+
class ListVectorIndexesResponse(BaseModel):
182180
indexes: List[ListIndexesResponseItem]
183-
nextToken: Optional[str]
181+
nextToken: Optional[str] = None
184182

185183

186184
class VectorIndex(BaseModel):
@@ -195,6 +193,10 @@ class VectorIndex(BaseModel):
195193
creation_time: Optional[datetime] = None
196194

197195

196+
class GetVectorIndexResponse(BaseModel):
197+
index: VectorIndex
198+
199+
198200
VectorFilter = Dict[str, Any]
199201

200202

@@ -205,7 +207,7 @@ class VectorData(BaseModel):
205207
class VectorObject(BaseModel):
206208
key: str
207209
data: VectorData
208-
metadata: Optional[dict[str, Any]] = None
210+
metadata: Optional[dict[str, Union[str, bool, float]]] = None
209211

210212
def as_json(self) -> JSON:
211213
return {"key": self.key, "data": dict(self.data), "metadata": self.metadata}
@@ -224,7 +226,7 @@ class GetVectorsResponse(BaseModel):
224226

225227
class ListVectorsResponse(BaseModel):
226228
vectors: List[VectorMatch]
227-
nextToken: Optional[str]
229+
nextToken: Optional[str] = None
228230

229231

230232
class QueryVectorsResponse(BaseModel):
@@ -247,3 +249,27 @@ class AnalyticsBucket(BaseModel):
247249

248250
class AnalyticsBucketDeleteResponse(BaseModel):
249251
message: str
252+
253+
254+
class VectorBucketEncryptionConfiguration(BaseModel):
255+
kmsKeyArn: Optional[str] = None
256+
sseType: Optional[str] = None
257+
258+
259+
class VectorBucket(BaseModel):
260+
vectorBucketName: str
261+
creationTime: Optional[datetime] = None
262+
encryptionConfiguration: Optional[VectorBucketEncryptionConfiguration] = None
263+
264+
265+
class GetVectorBucketResponse(BaseModel):
266+
vectorBucket: VectorBucket
267+
268+
269+
class ListVectorBucketsItem(BaseModel):
270+
vectorBucketName: str
271+
272+
273+
class ListVectorBucketsResponse(BaseModel):
274+
vectorBuckets: List[ListVectorBucketsItem]
275+
nextToken: Optional[str] = None

0 commit comments

Comments
 (0)