Skip to content

Commit a354220

Browse files
add kwargs support to all vectorizer embed methods
1 parent ad9bb21 commit a354220

File tree

7 files changed

+42
-27
lines changed

7 files changed

+42
-27
lines changed

redisvl/utils/vectorize/text/azureopenai.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ def embed_many(
205205

206206
embeddings: List = []
207207
for batch in self.batchify(texts, batch_size, preprocess):
208-
response = self._client.embeddings.create(input=batch, model=self.model)
208+
response = self._client.embeddings.create(
209+
input=batch, model=self.model, **kwargs
210+
)
209211
embeddings += [
210212
self._process_embedding(r.embedding, as_buffer, dtype)
211213
for r in response.data
@@ -248,7 +250,9 @@ def embed(
248250

249251
dtype = kwargs.pop("dtype", self.dtype)
250252

251-
result = self._client.embeddings.create(input=[text], model=self.model)
253+
result = self._client.embeddings.create(
254+
input=[text], model=self.model, **kwargs
255+
)
252256
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
253257

254258
@retry(
@@ -292,7 +296,7 @@ async def aembed_many(
292296
embeddings: List = []
293297
for batch in self.batchify(texts, batch_size, preprocess):
294298
response = await self._aclient.embeddings.create(
295-
input=batch, model=self.model
299+
input=batch, model=self.model, **kwargs
296300
)
297301
embeddings += [
298302
self._process_embedding(r.embedding, as_buffer, dtype)
@@ -336,7 +340,9 @@ async def aembed(
336340

337341
dtype = kwargs.pop("dtype", self.dtype)
338342

339-
result = await self._aclient.embeddings.create(input=[text], model=self.model)
343+
result = await self._aclient.embeddings.create(
344+
input=[text], model=self.model, **kwargs
345+
)
340346
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
341347

342348
@property

redisvl/utils/vectorize/text/bedrock.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def embed(
156156
text = preprocess(text)
157157

158158
response = self._client.invoke_model(
159-
modelId=self.model, body=json.dumps({"inputText": text})
159+
modelId=self.model, body=json.dumps({"inputText": text}), **kwargs
160160
)
161161
response_body = json.loads(response["body"].read())
162162
embedding = response_body["embedding"]
@@ -206,7 +206,7 @@ def embed_many(
206206
batch_embeddings = []
207207
for text in batch:
208208
response = self._client.invoke_model(
209-
modelId=self.model, body=json.dumps({"inputText": text})
209+
modelId=self.model, body=json.dumps({"inputText": text}), **kwargs
210210
)
211211
response_body = json.loads(response["body"].read())
212212
batch_embeddings.append(response_body["embedding"])

redisvl/utils/vectorize/text/cohere.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def embed(
156156
TypeError: If an invalid input_type is provided.
157157
158158
"""
159-
input_type = kwargs.get("input_type")
159+
input_type = kwargs.pop("input_type", None)
160160

161161
if not isinstance(text, str):
162162
raise TypeError("Must pass in a str value to embed.")
@@ -172,7 +172,7 @@ def embed(
172172
dtype = kwargs.pop("dtype", self.dtype)
173173

174174
embedding = self._client.embed(
175-
texts=[text], model=self.model, input_type=input_type
175+
texts=[text], model=self.model, input_type=input_type, **kwargs
176176
).embeddings[0]
177177
return self._process_embedding(embedding, as_buffer, dtype)
178178

@@ -227,7 +227,7 @@ def embed_many(
227227
TypeError: If an invalid input_type is provided.
228228
229229
"""
230-
input_type = kwargs.get("input_type")
230+
input_type = kwargs.pop("input_type", None)
231231

232232
if not isinstance(texts, list):
233233
raise TypeError("Must pass in a list of str values to embed.")
@@ -244,7 +244,7 @@ def embed_many(
244244
embeddings: List = []
245245
for batch in self.batchify(texts, batch_size, preprocess):
246246
response = self._client.embed(
247-
texts=batch, model=self.model, input_type=input_type
247+
texts=batch, model=self.model, input_type=input_type, **kwargs
248248
)
249249
embeddings += [
250250
self._process_embedding(embedding, as_buffer, dtype)

redisvl/utils/vectorize/text/mistral.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,9 @@ def embed_many(
155155

156156
embeddings: List = []
157157
for batch in self.batchify(texts, batch_size, preprocess):
158-
response = self._client.embeddings.create(model=self.model, inputs=batch)
158+
response = self._client.embeddings.create(
159+
model=self.model, inputs=batch, **kwargs
160+
)
159161
embeddings += [
160162
self._process_embedding(r.embedding, as_buffer, dtype)
161163
for r in response.data
@@ -198,7 +200,9 @@ def embed(
198200

199201
dtype = kwargs.pop("dtype", self.dtype)
200202

201-
result = self._client.embeddings.create(model=self.model, inputs=[text])
203+
result = self._client.embeddings.create(
204+
model=self.model, inputs=[text], **kwargs
205+
)
202206
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
203207

204208
@retry(
@@ -242,7 +246,7 @@ async def aembed_many(
242246
embeddings: List = []
243247
for batch in self.batchify(texts, batch_size, preprocess):
244248
response = await self._client.embeddings.create_async(
245-
model=self.model, inputs=batch
249+
model=self.model, inputs=batch, **kwargs
246250
)
247251
embeddings += [
248252
self._process_embedding(r.embedding, as_buffer, dtype)
@@ -287,7 +291,7 @@ async def aembed(
287291
dtype = kwargs.pop("dtype", self.dtype)
288292

289293
result = await self._client.embeddings.create_async(
290-
model=self.model, inputs=[text]
294+
model=self.model, inputs=[text], **kwargs
291295
)
292296
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
293297

redisvl/utils/vectorize/text/openai.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ def embed_many(
156156

157157
embeddings: List = []
158158
for batch in self.batchify(texts, batch_size, preprocess):
159-
response = self._client.embeddings.create(input=batch, model=self.model)
159+
response = self._client.embeddings.create(
160+
input=batch, model=self.model, **kwargs
161+
)
160162
embeddings += [
161163
self._process_embedding(r.embedding, as_buffer, dtype)
162164
for r in response.data
@@ -199,7 +201,9 @@ def embed(
199201

200202
dtype = kwargs.pop("dtype", self.dtype)
201203

202-
result = self._client.embeddings.create(input=[text], model=self.model)
204+
result = self._client.embeddings.create(
205+
input=[text], model=self.model, **kwargs
206+
)
203207
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
204208

205209
@retry(
@@ -243,7 +247,7 @@ async def aembed_many(
243247
embeddings: List = []
244248
for batch in self.batchify(texts, batch_size, preprocess):
245249
response = await self._aclient.embeddings.create(
246-
input=batch, model=self.model
250+
input=batch, model=self.model, **kwargs
247251
)
248252
embeddings += [
249253
self._process_embedding(r.embedding, as_buffer, dtype)
@@ -287,7 +291,9 @@ async def aembed(
287291

288292
dtype = kwargs.pop("dtype", self.dtype)
289293

290-
result = await self._aclient.embeddings.create(input=[text], model=self.model)
294+
result = await self._aclient.embeddings.create(
295+
input=[text], model=self.model, **kwargs
296+
)
291297
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
292298

293299
@property

redisvl/utils/vectorize/text/vertexai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def embed_many(
168168

169169
embeddings: List = []
170170
for batch in self.batchify(texts, batch_size, preprocess):
171-
response = self._client.get_embeddings(batch)
171+
response = self._client.get_embeddings(batch, **kwargs)
172172
embeddings += [
173173
self._process_embedding(r.values, as_buffer, dtype) for r in response
174174
]
@@ -210,7 +210,7 @@ def embed(
210210

211211
dtype = kwargs.pop("dtype", self.dtype)
212212

213-
result = self._client.get_embeddings([text])
213+
result = self._client.get_embeddings([text], **kwargs)
214214
return self._process_embedding(result[0].values, as_buffer, dtype)
215215

216216
@property

redisvl/utils/vectorize/text/voyageai.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ def embed_many(
204204
TypeError: If an invalid input_type is provided.
205205
206206
"""
207-
input_type = kwargs.get("input_type")
208-
truncation = kwargs.get("truncation")
207+
input_type = kwargs.pop("input_type", None)
208+
truncation = kwargs.pop("truncation", None)
209209
dtype = kwargs.pop("dtype", self.dtype)
210210

211211
if not isinstance(texts, list):
@@ -235,7 +235,7 @@ def embed_many(
235235
embeddings: List = []
236236
for batch in self.batchify(texts, batch_size, preprocess):
237237
response = self._client.embed(
238-
texts=batch, model=self.model, input_type=input_type
238+
texts=batch, model=self.model, input_type=input_type, **kwargs
239239
)
240240
embeddings += [
241241
self._process_embedding(embedding, as_buffer, dtype)
@@ -284,8 +284,8 @@ async def aembed_many(
284284
TypeError: If an invalid input_type is provided.
285285
286286
"""
287-
input_type = kwargs.get("input_type")
288-
truncation = kwargs.get("truncation")
287+
input_type = kwargs.pop("input_type", None)
288+
truncation = kwargs.pop("truncation", None)
289289
dtype = kwargs.pop("dtype", self.dtype)
290290

291291
if not isinstance(texts, list):
@@ -315,7 +315,7 @@ async def aembed_many(
315315
embeddings: List = []
316316
for batch in self.batchify(texts, batch_size, preprocess):
317317
response = await self._aclient.embed(
318-
texts=batch, model=self.model, input_type=input_type
318+
texts=batch, model=self.model, input_type=input_type, **kwargs
319319
)
320320
embeddings += [
321321
self._process_embedding(embedding, as_buffer, dtype)
@@ -360,7 +360,6 @@ async def aembed(
360360
Raises:
361361
TypeError: If an invalid input_type is provided.
362362
"""
363-
364363
result = await self.aembed_many(
365364
texts=[text], preprocess=preprocess, as_buffer=as_buffer, **kwargs
366365
)

0 commit comments

Comments
 (0)