1+ import os
2+ from typing import List
3+ import boto3
4+ import json
5+ from deepsearcher .embedding .base import BaseEmbedding
6+
# Bedrock model identifiers for the embedding models this module knows about.
MODEL_ID_TITAN_TEXT_G1 = "amazon.titan-embed-text-v1"
MODEL_ID_TITAN_TEXT_V2 = "amazon.titan-embed-text-v2:0"
MODEL_ID_TITAN_MULTIMODAL_G1 = "amazon.titan-embed-image-v1"
MODEL_ID_COHERE_ENGLISH_V3 = "cohere.embed-english-v3"
MODEL_ID_COHERE_MULTILINGUAL_V3 = "cohere.embed-multilingual-v3"

# Embedding dimension reported for each known model ID.
BEDROCK_MODEL_DIM_MAP = {
    MODEL_ID_TITAN_TEXT_G1: 1536,
    MODEL_ID_TITAN_TEXT_V2: 1024,
    MODEL_ID_TITAN_MULTIMODAL_G1: 1024,
    MODEL_ID_COHERE_ENGLISH_V3: 1024,
    MODEL_ID_COHERE_MULTILINGUAL_V3: 1024,
}

# Model used when the caller does not choose one.
DEFAULT_MODEL_ID = MODEL_ID_TITAN_TEXT_V2
22+
class BedrockEmbedding(BaseEmbedding):
    """Embedding provider that calls Amazon Bedrock through a boto3
    ``bedrock-runtime`` client.

    Amazon Titan models are sent an ``inputText`` payload; Cohere models
    (model IDs starting with ``cohere.``) are sent a ``texts`` /
    ``input_type`` payload, since the two families use different request and
    response schemas.
    """

    # Region used when neither the caller nor the environment names one.
    DEFAULT_REGION_NAME = "us-east-1"

    # Model-ID prefix identifying the Cohere embedding family.
    _COHERE_PREFIX = "cohere."

    def __init__(self, model: str = DEFAULT_MODEL_ID, **kwargs):
        """
        Args:
            model (`str`):
                Bedrock model ID. Defaults to 'amazon.titan-embed-text-v2:0'
                (dimensions include 256, 512, 1024, default is 1024).
                Passing `None` also selects the default.
            **kwargs:
                model_name (`str`): alias for `model`; only honoured when
                    `model` is `None` or the default model ID.
                aws_access_key_id (`str`): defaults to the
                    `AWS_ACCESS_KEY_ID` environment variable.
                aws_secret_access_key (`str`): defaults to the
                    `AWS_SECRET_ACCESS_KEY` environment variable.
                region_name (`str`): AWS region of the Bedrock endpoint.
                    Falls back to the `AWS_REGION` / `AWS_DEFAULT_REGION`
                    environment variables, then to 'us-east-1'.
        """
        aws_access_key_id = kwargs.pop(
            "aws_access_key_id", os.getenv("AWS_ACCESS_KEY_ID")
        )
        aws_secret_access_key = kwargs.pop(
            "aws_secret_access_key", os.getenv("AWS_SECRET_ACCESS_KEY")
        )
        # Explicit argument wins, then the environment, then the old
        # hard-coded region so existing callers keep working.
        region_name = (
            kwargs.pop("region_name", None)
            or os.getenv("AWS_REGION")
            or os.getenv("AWS_DEFAULT_REGION")
            or self.DEFAULT_REGION_NAME
        )

        # `model_name` only overrides `model` when the caller did not pick a
        # non-default model explicitly.
        model_name = kwargs.pop("model_name", None)
        if (model is None or model == DEFAULT_MODEL_ID) and model_name is not None:
            model = model_name
        # Never leave the model unset: a None modelId cannot be invoked.
        self.model = model if model is not None else DEFAULT_MODEL_ID

        self.client = boto3.client(
            "bedrock-runtime",
            region_name=region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )

    def _is_cohere_model(self) -> bool:
        """Return True when the configured model ID belongs to the Cohere family."""
        return str(self.model).startswith(self._COHERE_PREFIX)

    def _embed(self, text: str, cohere_input_type: str) -> List[float]:
        """Invoke the model for one text and return its embedding vector.

        `cohere_input_type` is only used for Cohere models, whose request
        body requires an `input_type` field.
        """
        use_cohere = self._is_cohere_model()
        if use_cohere:
            payload = {"texts": [text], "input_type": cohere_input_type}
        else:
            payload = {"inputText": text}

        response = self.client.invoke_model(
            modelId=self.model, body=json.dumps(payload)
        )
        model_response = json.loads(response["body"].read())

        if use_cohere:
            # Cohere returns one vector per submitted text; one was sent.
            return model_response["embeddings"][0]
        return model_response["embedding"]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string and return its vector."""
        return self._embed(text, "search_query")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in `texts` (one model call per text), in order.

        Returns an empty list when `texts` is empty.
        """
        return [self._embed(text, "search_document") for text in texts]

    @property
    def dimension(self) -> int:
        """Embedding dimension of the configured model.

        Raises:
            KeyError: if the model ID is not listed in `BEDROCK_MODEL_DIM_MAP`.
        """
        return BEDROCK_MODEL_DIM_MAP[self.model]
0 commit comments