Skip to content

Commit da4b129

Browse files
sets default dtype in vectorizers to float32
1 parent 468ecd4 commit da4b129

File tree

9 files changed, with 25 additions and 31 deletions.

redisvl/utils/vectorize/base.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,7 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None):
             else:
                 yield seq[pos : pos + size]
 
-    def _process_embedding(
-        self, embedding: List[float], as_buffer: bool, dtype: Optional[str]
-    ):
+    def _process_embedding(self, embedding: List[float], as_buffer: bool, dtype: str):
         if as_buffer:
-            if not dtype:
-                raise RuntimeError(
-                    "dtype is required if converting from float to byte string."
-                )
             return array_to_buffer(embedding, dtype)
         return embedding

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def embed_many(
190190
if len(texts) > 0 and not isinstance(texts[0], str):
191191
raise TypeError("Must pass in a list of str values to embed.")
192192

193-
dtype = kwargs.pop("dtype", None)
193+
dtype = kwargs.pop("dtype", "float32")
194194

195195
embeddings: List = []
196196
for batch in self.batchify(texts, batch_size, preprocess):
@@ -234,7 +234,7 @@ def embed(
234234
if preprocess:
235235
text = preprocess(text)
236236

237-
dtype = kwargs.pop("dtype", None)
237+
dtype = kwargs.pop("dtype", "float32")
238238

239239
result = self._client.embeddings.create(input=[text], model=self.model)
240240
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@@ -274,7 +274,7 @@ async def aembed_many(
274274
if len(texts) > 0 and not isinstance(texts[0], str):
275275
raise TypeError("Must pass in a list of str values to embed.")
276276

277-
dtype = kwargs.pop("dtype", None)
277+
dtype = kwargs.pop("dtype", "float32")
278278

279279
embeddings: List = []
280280
for batch in self.batchify(texts, batch_size, preprocess):
@@ -320,7 +320,7 @@ async def aembed(
320320
if preprocess:
321321
text = preprocess(text)
322322

323-
dtype = kwargs.pop("dtype", None)
323+
dtype = kwargs.pop("dtype", "float32")
324324

325325
result = await self._aclient.embeddings.create(input=[text], model=self.model)
326326
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def embed(
145145
response_body = json.loads(response["body"].read())
146146
embedding = response_body["embedding"]
147147

148-
dtype = kwargs.pop("dtype", None)
148+
dtype = kwargs.pop("dtype", "float32")
149149
return self._process_embedding(embedding, as_buffer, dtype)
150150

151151
@retry(
@@ -181,7 +181,7 @@ def embed_many(
181181
raise TypeError("Texts must be a list of strings")
182182

183183
embeddings: List[List[float]] = []
184-
dtype = kwargs.pop("dtype", None)
184+
dtype = kwargs.pop("dtype", "float32")
185185

186186
for batch in self.batchify(texts, batch_size, preprocess):
187187
# Process each text in the batch individually since Bedrock

redisvl/utils/vectorize/text/cohere.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def embed(
159159
if preprocess:
160160
text = preprocess(text)
161161

162-
dtype = kwargs.pop("dtype", None)
162+
dtype = kwargs.pop("dtype", "float32")
163163

164164
embedding = self._client.embed(
165165
texts=[text], model=self.model, input_type=input_type
@@ -228,7 +228,7 @@ def embed_many(
228228
See https://docs.cohere.com/reference/embed."
229229
)
230230

231-
dtype = kwargs.pop("dtype", None)
231+
dtype = kwargs.pop("dtype", "float32")
232232

233233
embeddings: List = []
234234
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/utils/vectorize/text/custom.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def embed(
173173
if preprocess:
174174
text = preprocess(text)
175175

176-
dtype = kwargs.pop("dtype", None)
176+
dtype = kwargs.pop("dtype", "float32")
177177

178178
result = self._embed_func(text, **kwargs)
179179
return self._process_embedding(result, as_buffer, dtype)
@@ -212,7 +212,7 @@ def embed_many(
212212
if not self._embed_many_func:
213213
raise NotImplementedError
214214

215-
dtype = kwargs.pop("dtype", None)
215+
dtype = kwargs.pop("dtype", "float32")
216216

217217
embeddings: List = []
218218
for batch in self.batchify(texts, batch_size, preprocess):
@@ -254,7 +254,7 @@ async def aembed(
254254
if preprocess:
255255
text = preprocess(text)
256256

257-
dtype = kwargs.pop("dtype", None)
257+
dtype = kwargs.pop("dtype", "float32")
258258

259259
result = await self._aembed_func(text, **kwargs)
260260
return self._process_embedding(result, as_buffer, dtype)
@@ -293,7 +293,7 @@ async def aembed_many(
293293
if not self._aembed_many_func:
294294
raise NotImplementedError
295295

296-
dtype = kwargs.pop("dtype", None)
296+
dtype = kwargs.pop("dtype", "float32")
297297

298298
embeddings: List = []
299299
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/utils/vectorize/text/huggingface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def embed(
100100
if preprocess:
101101
text = preprocess(text)
102102

103-
dtype = kwargs.pop("dtype", None)
103+
dtype = kwargs.pop("dtype", "float32")
104104

105105
embedding = self._client.encode([text], **kwargs)[0]
106106
return self._process_embedding(embedding.tolist(), as_buffer, dtype)
@@ -136,7 +136,7 @@ def embed_many(
136136
if len(texts) > 0 and not isinstance(texts[0], str):
137137
raise TypeError("Must pass in a list of str values to embed.")
138138

139-
dtype = kwargs.pop("dtype", None)
139+
dtype = kwargs.pop("dtype", "float32")
140140

141141
embeddings: List = []
142142
for batch in self.batchify(texts, batch_size, preprocess):

redisvl/utils/vectorize/text/mistral.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def embed_many(
140140
if len(texts) > 0 and not isinstance(texts[0], str):
141141
raise TypeError("Must pass in a list of str values to embed.")
142142

143-
dtype = kwargs.pop("dtype", None)
143+
dtype = kwargs.pop("dtype", "float32")
144144

145145
embeddings: List = []
146146
for batch in self.batchify(texts, batch_size, preprocess):
@@ -184,7 +184,7 @@ def embed(
184184
if preprocess:
185185
text = preprocess(text)
186186

187-
dtype = kwargs.pop("dtype", None)
187+
dtype = kwargs.pop("dtype", "float32")
188188

189189
result = self._client.embeddings(model=self.model, input=[text])
190190
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@@ -224,7 +224,7 @@ async def aembed_many(
224224
if len(texts) > 0 and not isinstance(texts[0], str):
225225
raise TypeError("Must pass in a list of str values to embed.")
226226

227-
dtype = kwargs.pop("dtype", None)
227+
dtype = kwargs.pop("dtype", "float32")
228228

229229
embeddings: List = []
230230
for batch in self.batchify(texts, batch_size, preprocess):
@@ -268,7 +268,7 @@ async def aembed(
268268
if preprocess:
269269
text = preprocess(text)
270270

271-
dtype = kwargs.pop("dtype", None)
271+
dtype = kwargs.pop("dtype", "float32")
272272

273273
result = await self._aclient.embeddings(model=self.model, input=[text])
274274
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

redisvl/utils/vectorize/text/openai.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def embed_many(
144144
if len(texts) > 0 and not isinstance(texts[0], str):
145145
raise TypeError("Must pass in a list of str values to embed.")
146146

147-
dtype = kwargs.pop("dtype", None)
147+
dtype = kwargs.pop("dtype", "float32")
148148

149149
embeddings: List = []
150150
for batch in self.batchify(texts, batch_size, preprocess):
@@ -188,7 +188,7 @@ def embed(
188188
if preprocess:
189189
text = preprocess(text)
190190

191-
dtype = kwargs.pop("dtype", None)
191+
dtype = kwargs.pop("dtype", "float32")
192192

193193
result = self._client.embeddings.create(input=[text], model=self.model)
194194
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@@ -228,7 +228,7 @@ async def aembed_many(
228228
if len(texts) > 0 and not isinstance(texts[0], str):
229229
raise TypeError("Must pass in a list of str values to embed.")
230230

231-
dtype = kwargs.pop("dtype", None)
231+
dtype = kwargs.pop("dtype", "float32")
232232

233233
embeddings: List = []
234234
for batch in self.batchify(texts, batch_size, preprocess):
@@ -274,7 +274,7 @@ async def aembed(
274274
if preprocess:
275275
text = preprocess(text)
276276

277-
dtype = kwargs.pop("dtype", None)
277+
dtype = kwargs.pop("dtype", "float32")
278278

279279
result = await self._aclient.embeddings.create(input=[text], model=self.model)
280280
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)

redisvl/utils/vectorize/text/vertexai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ def embed_many(
151151
if len(texts) > 0 and not isinstance(texts[0], str):
152152
raise TypeError("Must pass in a list of str values to embed.")
153153

154-
dtype = kwargs.pop("dtype", None)
154+
dtype = kwargs.pop("dtype", "float32")
155155

156156
embeddings: List = []
157157
for batch in self.batchify(texts, batch_size, preprocess):
@@ -194,7 +194,7 @@ def embed(
194194
if preprocess:
195195
text = preprocess(text)
196196

197-
dtype = kwargs.pop("dtype", None)
197+
dtype = kwargs.pop("dtype", "float32")
198198

199199
result = self._client.get_embeddings([text])
200200
return self._process_embedding(result[0].values, as_buffer, dtype)

0 commit comments

Comments
 (0)