|
| 1 | +import os |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +from redisvl.extensions.cache.embeddings import EmbeddingsCache |
1 | 5 | from redisvl.utils.vectorize.base import BaseVectorizer, Vectorizers |
2 | 6 | from redisvl.utils.vectorize.text.azureopenai import AzureOpenAITextVectorizer |
3 | 7 | from redisvl.utils.vectorize.text.bedrock import BedrockTextVectorizer |
|
23 | 27 | ] |
24 | 28 |
|
25 | 29 |
|
26 | | -def vectorizer_from_dict(vectorizer: dict) -> BaseVectorizer: |
| 30 | +def vectorizer_from_dict( |
| 31 | + vectorizer: dict, |
| 32 | + cache: dict = {}, |
| 33 | + cache_folder=os.getenv("SENTENCE_TRANSFORMERS_HOME"), |
| 34 | +) -> BaseVectorizer: |
27 | 35 | vectorizer_type = Vectorizers(vectorizer["type"]) |
28 | 36 | model = vectorizer["model"] |
| 37 | + |
| 38 | + args = {"model": model} |
| 39 | + if cache: |
| 40 | + emb_cache = EmbeddingsCache(**cache) |
| 41 | + args["cache"] = emb_cache |
| 42 | + args["cache_folder"] = cache_folder |
| 43 | + |
29 | 44 | if vectorizer_type == Vectorizers.cohere: |
30 | | - return CohereTextVectorizer(model=model) |
| 45 | + return CohereTextVectorizer(**args) |
31 | 46 | elif vectorizer_type == Vectorizers.openai: |
32 | | - return OpenAITextVectorizer(model=model) |
| 47 | + return OpenAITextVectorizer(**args) |
33 | 48 | elif vectorizer_type == Vectorizers.azure_openai: |
34 | | - return AzureOpenAITextVectorizer(model=model) |
| 49 | + return AzureOpenAITextVectorizer(**args) |
35 | 50 | elif vectorizer_type == Vectorizers.hf: |
36 | | - return HFTextVectorizer(model=model) |
| 51 | + return HFTextVectorizer(**args) |
37 | 52 | elif vectorizer_type == Vectorizers.mistral: |
38 | | - return MistralAITextVectorizer(model=model) |
| 53 | + return MistralAITextVectorizer(**args) |
39 | 54 | elif vectorizer_type == Vectorizers.vertexai: |
40 | | - return VertexAITextVectorizer(model=model) |
| 55 | + return VertexAITextVectorizer(**args) |
41 | 56 | elif vectorizer_type == Vectorizers.voyageai: |
42 | | - return VoyageAITextVectorizer(model=model) |
| 57 | + return VoyageAITextVectorizer(**args) |
43 | 58 | else: |
44 | 59 | raise ValueError(f"Unsupported vectorizer type: {vectorizer_type}") |
0 commit comments