Skip to content

Commit dac6d41

Browse files
hiimivantangivantang
andauthored
Added support for Amazon Bedrock embeddings. (#41)
* added support for Bedrock embedding models * added support for Bedrock embedding models --------- Co-authored-by: ivantang <zilliz@ivantang-work.lan>
1 parent 66061a9 commit dac6d41

File tree

4 files changed

+61
-2
lines changed

4 files changed

+61
-2
lines changed

deepsearcher/embedding/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from .milvus_embedding import MilvusEmbedding
22
from .openai_embedding import OpenAIEmbedding
33
from .voyage_embedding import VoyageEmbedding
4+
from .bedrock_embedding import BedrockEmbedding
45

56
__all__ = [
67
"MilvusEmbedding",
78
"OpenAIEmbedding",
89
"VoyageEmbedding",
9-
]
10+
"BedrockEmbedding"
11+
]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import os
2+
from typing import List
3+
import boto3
4+
import json
5+
from deepsearcher.embedding.base import BaseEmbedding
6+
7+
MODEL_ID_TITAN_TEXT_G1 = "amazon.titan-embed-text-v1"
8+
MODEL_ID_TITAN_TEXT_V2 = "amazon.titan-embed-text-v2:0"
9+
MODEL_ID_TITAN_MULTIMODAL_G1 = "amazon.titan-embed-image-v1"
10+
MODEL_ID_COHERE_ENGLISH_V3 = "cohere.embed-english-v3"
11+
MODEL_ID_COHERE_MULTILINGUAL_V3 = "cohere.embed-multilingual-v3"
12+
13+
BEDROCK_MODEL_DIM_MAP = {
14+
MODEL_ID_TITAN_TEXT_G1: 1536,
15+
MODEL_ID_TITAN_TEXT_V2: 1024,
16+
MODEL_ID_TITAN_MULTIMODAL_G1: 1024,
17+
MODEL_ID_COHERE_ENGLISH_V3: 1024,
18+
MODEL_ID_COHERE_MULTILINGUAL_V3: 1024
19+
}
20+
21+
DEFAULT_MODEL_ID = MODEL_ID_TITAN_TEXT_V2
22+
23+
class BedrockEmbedding(BaseEmbedding):
24+
def __init__(self, model: str = DEFAULT_MODEL_ID, **kwargs):
25+
"""
26+
Args:
27+
model_name (`str`):
28+
Can be one of the following:
29+
'amazon.titan-embed-text-v2:0': dimensions include 256, 512, 1024, default is 1024,
30+
"""
31+
32+
aws_access_key_id = kwargs.pop("aws_access_key_id", os.getenv("AWS_ACCESS_KEY_ID"))
33+
aws_secret_access_key = kwargs.pop("aws_secret_access_key", os.getenv("AWS_SECRET_ACCESS_KEY"))
34+
35+
if model in {None, DEFAULT_MODEL_ID} and "model_name" in kwargs:
36+
model = kwargs.pop("model_name") #overwrites `model` with `model_name`
37+
38+
self.model = model
39+
40+
#TODO: initiate boto3 client
41+
self.client = boto3.client("bedrock-runtime",
42+
region_name="us-east-1", #FIXME: allow users to specify
43+
aws_access_key_id=aws_access_key_id,
44+
aws_secret_access_key=aws_secret_access_key)
45+
46+
def embed_query(self, text: str) -> List[float]:
47+
response = self.client.invoke_model(modelId=self.model, body=json.dumps({"inputText": text}))
48+
model_response = json.loads(response["body"].read())
49+
embedding = model_response["embedding"]
50+
return embedding
51+
52+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
53+
return [self.embed_query(text) for text in texts]
54+
55+
@property
56+
def dimension(self) -> int:
57+
return BEDROCK_MODEL_DIM_MAP[self.model]

deepsearcher/online_query.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def query(
2020
) -> Tuple[str, List[RetrievalResult], int]:
2121
return asyncio.run(async_query(original_query, max_iter))
2222

23-
2423
async def async_query(
2524
original_query: str, max_iter: int = 3
2625
) -> Tuple[str, List[RetrievalResult], int]:

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ termcolor
1010
fastapi
1111
uvicorn
1212
pydantic-settings
13+
boto3

0 commit comments

Comments
 (0)