Commit 0087f07

Add Amazon Bedrock Text vectorizer (#143)
1 parent acf3c66 commit 0087f07

File tree

8 files changed, +861 -408 lines changed

conftest.py

Lines changed: 7 additions & 0 deletions

@@ -71,6 +71,13 @@ def gcp_location():
 def gcp_project_id():
     return os.getenv("GCP_PROJECT_ID")
 
+@pytest.fixture
+def aws_credentials():
+    return {
+        "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
+        "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
+        "region_name": os.getenv("AWS_REGION", "us-east-1")
+    }
 
 @pytest.fixture
 def sample_data():
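
A minimal sketch, not part of this commit, of how a test might consume the new fixture; the test name and skip behavior are illustrative only:

```python
import pytest

from redisvl.utils.vectorize import BedrockTextVectorizer


def test_bedrock_embed_smoke(aws_credentials):
    # Skip rather than fail on machines without AWS credentials configured.
    if not aws_credentials["aws_access_key_id"]:
        pytest.skip("AWS credentials not set")

    vectorizer = BedrockTextVectorizer(api_config=aws_credentials)
    embedding = vectorizer.embed("hello world")

    assert isinstance(embedding, list)
    assert len(embedding) == vectorizer.dims
```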

docs/api/vectorizer.rst

Lines changed: 35 additions & 0 deletions

@@ -61,6 +61,41 @@ CohereTextVectorizer
    :show-inheritance:
    :members:
 
+BedrockTextVectorizer
+=====================
+
+.. _bedrocktextvectorizer_api:
+
+.. currentmodule:: redisvl.utils.vectorize.text.bedrock
+
+.. autoclass:: BedrockTextVectorizer
+   :show-inheritance:
+   :members:
+
+The BedrockTextVectorizer class utilizes Amazon Bedrock's API to generate
+embeddings for text data. This vectorizer requires AWS credentials, which can be
+provided via environment variables (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY,
+AWS_REGION) or through the api_config parameter.
+
+Example::
+
+    # Initialize with environment variables
+    vectorizer = BedrockTextVectorizer(model_id="amazon.titan-embed-text-v2:0")
+
+    # Or with explicit credentials
+    vectorizer = BedrockTextVectorizer(
+        model_id="amazon.titan-embed-text-v2:0",
+        api_config={
+            "aws_access_key_id": "your_access_key",
+            "aws_secret_access_key": "your_secret_key",
+            "region_name": "us-east-1"
+        }
+    )
+
+    # Generate embeddings
+    embedding = vectorizer.embed("Hello world")
+    embeddings = vectorizer.embed_many(["Hello", "World"])
+
 
 CustomTextVectorizer
 ====================
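
The class also accepts an `as_buffer` flag and a `preprocess` callable (both visible in the implementation below). A hedged extension of the example, assuming `dtype` follows the same convention as the other vectorizers:

```python
# Return a packed byte string for direct storage in a Redis hash field;
# dtype="float32" is assumed here, mirroring the other vectorizers.
buffer = vectorizer.embed("Hello world", as_buffer=True, dtype="float32")

# A preprocess callable is applied to the text before it is sent to Bedrock.
embedding = vectorizer.embed("  Hello world  ", preprocess=str.strip)
```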

docs/user_guide/vectorizers_04.ipynb

Lines changed: 71 additions & 1 deletion

@@ -13,7 +13,8 @@
    "3. Vertex AI\n",
    "4. Cohere\n",
    "5. Mistral AI\n",
-   "6. Bringing your own vectorizer\n",
+   "6. Amazon Bedrock\n",
+   "7. Bringing your own vectorizer\n",
    "\n",
    "Before running this notebook, be sure to\n",
    "1. Have installed ``redisvl`` and have that environment active for this notebook.\n",
@@ -541,6 +542,75 @@
    "# print(test[:10])"
   ]
  },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "### Amazon Bedrock\n",
+   "\n",
+   "Amazon Bedrock provides fully managed foundation models for text embeddings. Install the required dependencies:\n",
+   "\n",
+   "```bash\n",
+   "pip install 'redisvl[bedrock]' # Installs boto3\n",
+   "```"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "#### Configure AWS credentials:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "import os\n",
+   "import getpass\n",
+   "\n",
+   "# Either set environment variables AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION\n",
+   "# Or configure directly:\n",
+   "os.environ[\"AWS_ACCESS_KEY_ID\"] = getpass.getpass(\"Enter AWS Access Key ID: \")\n",
+   "os.environ[\"AWS_SECRET_ACCESS_KEY\"] = getpass.getpass(\"Enter AWS Secret Key: \")\n",
+   "os.environ[\"AWS_REGION\"] = \"us-east-1\" # Change as needed"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "#### Create embeddings:"
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+   "from redisvl.utils.vectorize import BedrockTextVectorizer\n",
+   "\n",
+   "bedrock = BedrockTextVectorizer(\n",
+   "    model_id=\"amazon.titan-embed-text-v2:0\"\n",
+   ")\n",
+   "\n",
+   "# Single embedding\n",
+   "text = \"This is a test sentence.\"\n",
+   "embedding = bedrock.embed(text)\n",
+   "print(f\"Vector dimensions: {len(embedding)}\")\n",
+   "\n",
+   "# Multiple embeddings\n",
+   "sentences = [\n",
+   "    \"That is a happy dog\",\n",
+   "    \"That is a happy person\",\n",
+   "    \"Today is a sunny day\"\n",
+   "]\n",
+   "embeddings = bedrock.embed_many(sentences)"
+  ]
+ },
  {
   "cell_type": "markdown",
   "metadata": {},

poetry.lock

Lines changed: 506 additions & 391 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 0 deletions

@@ -32,13 +32,15 @@ sentence-transformers = { version = ">=2.2.2", optional = true }
 google-cloud-aiplatform = { version = ">=1.26", optional = true }
 cohere = { version = ">=4.44", optional = true }
 mistralai = { version = ">=0.2.0", optional = true }
+boto3 = { version = ">=1.34.0", optional = true }
 
 [tool.poetry.extras]
 openai = ["openai"]
 sentence-transformers = ["sentence-transformers"]
 google_cloud_aiplatform = ["google_cloud_aiplatform"]
 cohere = ["cohere"]
 mistralai = ["mistralai"]
+bedrock = ["boto3"]
 
 [tool.poetry.group.dev.dependencies]
 black = ">=20.8b1"
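
The extra only affects packaging; boto3 is still imported lazily inside the vectorizer. A sketch of the guard an application might add before constructing it (the message text is illustrative):

```python
# boto3 is only present when the optional extra is installed
# (pip install 'redisvl[bedrock]'), so check for it up front.
try:
    import boto3  # noqa: F401
except ImportError as exc:
    raise SystemExit(
        "Amazon Bedrock support requires boto3; install it with "
        "pip install 'redisvl[bedrock]'"
    ) from exc
```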

redisvl/utils/vectorize/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -1,5 +1,6 @@
 from redisvl.utils.vectorize.base import BaseVectorizer, Vectorizers
 from redisvl.utils.vectorize.text.azureopenai import AzureOpenAITextVectorizer
+from redisvl.utils.vectorize.text.bedrock import BedrockTextVectorizer
 from redisvl.utils.vectorize.text.cohere import CohereTextVectorizer
 from redisvl.utils.vectorize.text.custom import CustomTextVectorizer
 from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer
@@ -8,14 +9,15 @@
 from redisvl.utils.vectorize.text.vertexai import VertexAITextVectorizer
 
 __all__ = [
-    "BaseVectrorizer",
+    "BaseVectorizer",
     "CohereTextVectorizer",
     "HFTextVectorizer",
     "OpenAITextVectorizer",
     "VertexAITextVectorizer",
     "AzureOpenAITextVectorizer",
     "MistralAITextVectorizer",
     "CustomTextVectorizer",
+    "BedrockTextVectorizer",
 ]
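
With the corrected `__all__` entry and the new export, the class resolves identically from the package root and from the submodule; a quick illustrative check:

```python
import redisvl.utils.vectorize as vectorize
from redisvl.utils.vectorize import BedrockTextVectorizer
from redisvl.utils.vectorize.text.bedrock import BedrockTextVectorizer as Direct

# Both import paths point at the same class, and the name is publicly exported.
assert BedrockTextVectorizer is Direct
assert "BedrockTextVectorizer" in vectorize.__all__
```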

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 201 additions & 0 deletions

@@ -0,0 +1,201 @@
import json
import os
from typing import Any, Callable, Dict, List, Optional

from pydantic.v1 import PrivateAttr
from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

from redisvl.utils.vectorize.base import BaseVectorizer


class BedrockTextVectorizer(BaseVectorizer):
    """The BedrockTextVectorizer class utilizes Amazon Bedrock's API to generate
    embeddings for text data.

    This vectorizer is designed to interact with the Amazon Bedrock API,
    requiring AWS credentials for authentication. The credentials can be provided
    directly in the `api_config` dictionary or through environment variables:

    - AWS_ACCESS_KEY_ID
    - AWS_SECRET_ACCESS_KEY
    - AWS_REGION (defaults to us-east-1)

    The vectorizer supports synchronous operations with batch processing and
    preprocessing capabilities.

    .. code-block:: python

        # Initialize with explicit credentials
        vectorizer = BedrockTextVectorizer(
            model_id="amazon.titan-embed-text-v2:0",
            api_config={
                "aws_access_key_id": "your_access_key",
                "aws_secret_access_key": "your_secret_key",
                "region_name": "us-east-1"
            }
        )

        # Initialize using environment variables
        vectorizer = BedrockTextVectorizer()

        # Generate embeddings
        embedding = vectorizer.embed("Hello, world!")
        embeddings = vectorizer.embed_many(["Hello", "World"], batch_size=2)
    """

    _client: Any = PrivateAttr()

    def __init__(
        self,
        model_id: str = "amazon.titan-embed-text-v2:0",
        api_config: Optional[Dict[str, str]] = None,
    ) -> None:
        """Initialize the AWS Bedrock Vectorizer.

        Args:
            model_id (str): The Bedrock model ID to use. Defaults to amazon.titan-embed-text-v2:0.
            api_config (Optional[Dict[str, str]]): AWS credentials and config.
                Can include: aws_access_key_id, aws_secret_access_key, region_name.
                If not provided, will use environment variables.

        Raises:
            ValueError: If credentials are not provided in config or environment.
            ImportError: If boto3 is not installed.
        """
        try:
            import boto3  # type: ignore
        except ImportError:
            raise ImportError(
                "Amazon Bedrock vectorizer requires boto3. "
                "Please install with `pip install boto3`"
            )

        if api_config is None:
            api_config = {}

        aws_access_key_id = api_config.get(
            "aws_access_key_id", os.getenv("AWS_ACCESS_KEY_ID")
        )
        aws_secret_access_key = api_config.get(
            "aws_secret_access_key", os.getenv("AWS_SECRET_ACCESS_KEY")
        )
        region_name = api_config.get(
            "region_name", os.getenv("AWS_REGION", "us-east-1")
        )

        if not aws_access_key_id or not aws_secret_access_key:
            raise ValueError(
                "AWS credentials required. Provide via api_config or environment variables "
                "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY"
            )

        self._client = boto3.client(
            "bedrock-runtime",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            region_name=region_name,
        )

        super().__init__(model=model_id, dims=self._set_model_dims(model_id))

    def _set_model_dims(self, model_id: str) -> int:
        """Initialize model and determine embedding dimensions."""
        try:
            response = self._client.invoke_model(
                modelId=model_id, body=json.dumps({"inputText": "dimension test"})
            )
            response_body = json.loads(response["body"].read())
            embedding = response_body["embedding"]
            return len(embedding)
        except Exception as e:
            raise ValueError(f"Error initializing Bedrock model: {str(e)}")

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def embed(
        self,
        text: str,
        preprocess: Optional[Callable] = None,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[float]:
        """Embed a chunk of text using Amazon Bedrock.

        Args:
            text (str): Text to embed.
            preprocess (Optional[Callable]): Optional preprocessing function.
            as_buffer (bool): Whether to return as byte buffer.

        Returns:
            List[float]: The embedding vector.

        Raises:
            TypeError: If text is not a string.
        """
        if not isinstance(text, str):
            raise TypeError("Text must be a string")

        if preprocess:
            text = preprocess(text)

        response = self._client.invoke_model(
            modelId=self.model, body=json.dumps({"inputText": text})
        )
        response_body = json.loads(response["body"].read())
        embedding = response_body["embedding"]

        dtype = kwargs.pop("dtype", None)
        return self._process_embedding(embedding, as_buffer, dtype)

    @retry(
        wait=wait_random_exponential(min=1, max=60),
        stop=stop_after_attempt(6),
        retry=retry_if_not_exception_type(TypeError),
    )
    def embed_many(
        self,
        texts: List[str],
        preprocess: Optional[Callable] = None,
        batch_size: int = 10,
        as_buffer: bool = False,
        **kwargs,
    ) -> List[List[float]]:
        """Embed multiple texts using Amazon Bedrock.

        Args:
            texts (List[str]): List of texts to embed.
            preprocess (Optional[Callable]): Optional preprocessing function.
            batch_size (int): Size of batches for processing.
            as_buffer (bool): Whether to return as byte buffers.

        Returns:
            List[List[float]]: List of embedding vectors.

        Raises:
            TypeError: If texts is not a list of strings.
        """
        if not isinstance(texts, list):
            raise TypeError("Texts must be a list of strings")
        if texts and not isinstance(texts[0], str):
            raise TypeError("Texts must be a list of strings")

        embeddings: List[List[float]] = []
        dtype = kwargs.pop("dtype", None)

        for batch in self.batchify(texts, batch_size, preprocess):
            for text in batch:
                response = self._client.invoke_model(
                    modelId=self.model, body=json.dumps({"inputText": text})
                )
                response_body = json.loads(response["body"].read())
                embedding = response_body["embedding"]
                embeddings.append(self._process_embedding(embedding, as_buffer, dtype))

        return embeddings

    @property
    def type(self) -> str:
        return "bedrock"