55from httpx import AsyncClient , Headers
66from yarl import URL
77
8- from ..exceptions import VectorBucketException
8+ from ..exceptions import StorageApiError , VectorBucketException
99from ..types import (
1010 JSON ,
1111 DistanceMetric ,
12+ GetVectorBucketResponse ,
13+ GetVectorIndexResponse ,
1214 GetVectorsResponse ,
13- ListIndexesResponse ,
15+ ListVectorBucketsResponse ,
16+ ListVectorIndexesResponse ,
1417 ListVectorsResponse ,
1518 MetadataConfiguration ,
1619 QueryVectorsResponse ,
20+ VectorBucket ,
1721 VectorData ,
1822 VectorFilter ,
1923 VectorIndex ,
24+ VectorMatch ,
2025 VectorObject ,
2126)
22- from .request import RequestBuilder
27+ from .request import AsyncRequestBuilder
28+
29+
30+ # used to not send non-required values as `null`
31+ # for they cannot be null
32+ def remove_none (** kwargs : JSON ) -> JSON :
33+ return {key : val for key , val in kwargs .items () if val is not None }
2334
2435
2536class AsyncVectorBucketScope :
26- def __init__ (self , request : RequestBuilder , bucket_name : str ) -> None :
37+ def __init__ (self , request : AsyncRequestBuilder , bucket_name : str ) -> None :
2738 self ._request = request
2839 self ._bucket_name = bucket_name
2940
3041 def with_metadata (self , ** data : JSON ) -> JSON :
31- return { " vectorBucketName" : self ._bucket_name , ** data }
42+ return remove_none ( vectorBucketName = self ._bucket_name , ** data )
3243
3344 async def create_index (
3445 self ,
@@ -47,26 +58,29 @@ async def create_index(
4758 )
4859 await self ._request .send (http_method = "POST" , path = ["CreateIndex" ], body = body )
4960
50- async def get_index (self , index_name : str ) -> VectorIndex :
61+ async def get_index (self , index_name : str ) -> Optional [ VectorIndex ] :
5162 body = self .with_metadata (indexName = index_name )
52- data = await self ._request .send (
53- http_method = "POST" , path = ["GetIndex" ], body = body
54- )
55- return VectorIndex .model_validate (data .content )
63+ try :
64+ data = await self ._request .send (
65+ http_method = "POST" , path = ["GetIndex" ], body = body
66+ )
67+ return GetVectorIndexResponse .model_validate_json (data .content ).index
68+ except StorageApiError :
69+ return None
5670
5771 async def list_indexes (
5872 self ,
5973 next_token : Optional [str ] = None ,
6074 max_results : Optional [int ] = None ,
6175 prefix : Optional [str ] = None ,
62- ) -> ListIndexesResponse :
76+ ) -> ListVectorIndexesResponse :
6377 body = self .with_metadata (
6478 next_token = next_token , max_results = max_results , prefix = prefix
6579 )
6680 data = await self ._request .send (
6781 http_method = "POST" , path = ["ListIndexes" ], body = body
6882 )
69- return ListIndexesResponse . model_validate (data .content )
83+ return ListVectorIndexesResponse . model_validate_json (data .content )
7084
7185 async def delete_index (self , index_name : str ) -> None :
7286 body = self .with_metadata (indexName = index_name )
@@ -78,33 +92,33 @@ def index(self, index_name: str) -> AsyncVectorIndexScope:
7892
7993class AsyncVectorIndexScope :
8094 def __init__ (
81- self , request : RequestBuilder , bucket_name : str , index_name : str
95+ self , request : AsyncRequestBuilder , bucket_name : str , index_name : str
8296 ) -> None :
8397 self ._request = request
8498 self ._bucket_name = bucket_name
8599 self ._index_name = index_name
86100
87101 def with_metadata (self , ** data : JSON ) -> JSON :
88- return {
89- " vectorBucketName" : self ._bucket_name ,
90- " indexName" : self ._index_name ,
102+ return remove_none (
103+ vectorBucketName = self ._bucket_name ,
104+ indexName = self ._index_name ,
91105 ** data ,
92- }
106+ )
93107
94108 async def put (self , vectors : List [VectorObject ]) -> None :
95109 body = self .with_metadata (vectors = [v .as_json () for v in vectors ])
96110 await self ._request .send (http_method = "POST" , path = ["PutVectors" ], body = body )
97111
98112 async def get (
99113 self , * keys : str , return_data : bool = True , return_metadata : bool = True
100- ) -> GetVectorsResponse :
114+ ) -> List [ VectorMatch ] :
101115 body = self .with_metadata (
102116 keys = keys , returnData = return_data , returnMetadata = return_metadata
103117 )
104118 data = await self ._request .send (
105119 http_method = "POST" , path = ["GetVectors" ], body = body
106120 )
107- return GetVectorsResponse .model_validate (data .content )
121+ return GetVectorsResponse .model_validate_json (data .content ). vectors
108122
109123 async def list (
110124 self ,
@@ -126,7 +140,7 @@ async def list(
126140 data = await self ._request .send (
127141 http_method = "POST" , path = ["ListVectors" ], body = body
128142 )
129- return ListVectorsResponse .model_validate (data .content )
143+ return ListVectorsResponse .model_validate_json (data .content )
130144
131145 async def query (
132146 self ,
@@ -146,7 +160,7 @@ async def query(
146160 data = await self ._request .send (
147161 http_method = "POST" , path = ["QueryVectors" ], body = body
148162 )
149- return QueryVectorsResponse .model_validate (data .content )
163+ return QueryVectorsResponse .model_validate_json (data .content )
150164
151165 async def delete (self , keys : List [str ]) -> None :
152166 if 1 < len (keys ) or len (keys ) > 500 :
@@ -157,7 +171,7 @@ async def delete(self, keys: List[str]) -> None:
157171
158172class AsyncStorageVectorsClient :
159173 def __init__ (self , url : URL , headers : Headers , session : AsyncClient ) -> None :
160- self ._request = RequestBuilder (session , base_url = URL (url ), headers = headers )
174+ self ._request = AsyncRequestBuilder (session , base_url = URL (url ), headers = headers )
161175
162176 def from_ (self , bucket_name : str ) -> AsyncVectorBucketScope :
163177 return AsyncVectorBucketScope (self ._request , bucket_name )
@@ -168,7 +182,32 @@ async def create_bucket(self, bucket_name: str) -> None:
168182 http_method = "POST" , path = ["CreateVectorBucket" ], body = body
169183 )
170184
171- # async def get_bucket(self, bucket_name: str) -> GetBucketResponse:
172- # body = { 'vectorBucketName': bucket_name }
173- # data = await self._request.send(http_method='POST', path=['GetVectorBucket'], body=body)
174- # return GetVectorsResponse.model_validate(data.content)
185+ async def get_bucket (self , bucket_name : str ) -> Optional [VectorBucket ]:
186+ body = {"vectorBucketName" : bucket_name }
187+ try :
188+ data = await self ._request .send (
189+ http_method = "POST" , path = ["GetVectorBucket" ], body = body
190+ )
191+ return GetVectorBucketResponse .model_validate_json (
192+ data .content
193+ ).vectorBucket
194+ except StorageApiError :
195+ return None
196+
197+ async def list_buckets (
198+ self ,
199+ prefix : Optional [str ] = None ,
200+ max_results : Optional [int ] = None ,
201+ next_token : Optional [str ] = None ,
202+ ) -> ListVectorBucketsResponse :
203+ body = remove_none (prefix = prefix , maxResults = max_results , nextToken = next_token )
204+ data = await self ._request .send (
205+ http_method = "POST" , path = ["ListVectorBuckets" ], body = body
206+ )
207+ return ListVectorBucketsResponse .model_validate_json (data .content )
208+
209+ async def delete_bucket (self , bucket_name : str ) -> None :
210+ body = {"vectorBucketName" : bucket_name }
211+ await self ._request .send (
212+ http_method = "POST" , path = ["DeleteVectorBucket" ], body = body
213+ )
0 commit comments