Skip to content

Commit ee7e6d2

Browse files
committed
Consider vectorizer the owner of dtype
1 parent ca6396c commit ee7e6d2

File tree

2 files changed

+28
-30
lines changed

2 files changed

+28
-30
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 24 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -86,12 +86,28 @@ def __init__(
8686
else:
8787
prefix = name
8888

89-
# Set vectorizer default
90-
if vectorizer is None:
89+
# Validate a provided vectorizer or set the default
90+
if vectorizer:
91+
if not isinstance(vectorizer, BaseVectorizer):
92+
raise TypeError("Must provide a valid redisvl.vectorizer class.")
93+
else:
94+
dtype = kwargs.get("dtype")
95+
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
96+
9197
vectorizer = HFTextVectorizer(
92-
model="sentence-transformers/all-mpnet-base-v2"
98+
model="sentence-transformers/all-mpnet-base-v2",
99+
**vectorizer_kwargs,
93100
)
94101

102+
self._vectorizer = vectorizer
103+
104+
# Create semantic cache schema and index
105+
schema = SemanticCacheIndexSchema.from_params(
106+
name, prefix, vectorizer.dims, vectorizer.dtype
107+
)
108+
schema = self._modify_schema(schema, filterable_fields)
109+
self._index = SearchIndex(schema=schema)
110+
95111
# Process fields and other settings
96112
self.set_threshold(distance_threshold)
97113
self.return_fields = [
@@ -103,14 +119,6 @@ def __init__(
103119
METADATA_FIELD_NAME,
104120
]
105121

106-
# Create semantic cache schema and index
107-
dtype = kwargs.get("dtype", "float32")
108-
schema = SemanticCacheIndexSchema.from_params(
109-
name, prefix, vectorizer.dims, dtype
110-
)
111-
schema = self._modify_schema(schema, filterable_fields)
112-
self._index = SearchIndex(schema=schema)
113-
114122
# Handle redis connection
115123
if redis_client:
116124
self._index.set_client(redis_client)
@@ -128,20 +136,9 @@ def __init__(
128136
"If you wish to overwrite the index schema, set overwrite=True during initialization."
129137
)
130138

131-
# Create the search index
139+
# Create the search index in Redis
132140
self._index.create(overwrite=overwrite, drop=False)
133141

134-
# Initialize and validate vectorizer
135-
if not isinstance(vectorizer, BaseVectorizer):
136-
raise TypeError("Must provide a valid redisvl.vectorizer class.")
137-
138-
validate_vector_dims(
139-
vectorizer.dims,
140-
self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore
141-
)
142-
self._vectorizer = vectorizer
143-
self._dtype = self.index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr]
144-
145142
def _modify_schema(
146143
self,
147144
schema: SemanticCacheIndexSchema,
@@ -290,7 +287,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
290287
if not isinstance(prompt, str):
291288
raise TypeError("Prompt must be a string.")
292289

293-
return self._vectorizer.embed(prompt, dtype=self._dtype)
290+
return self._vectorizer.embed(prompt)
294291

295292
async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
296293
"""Converts a text prompt to its vector representation using the
@@ -372,7 +369,7 @@ def check(
372369
num_results=num_results,
373370
return_score=True,
374371
filter_expression=filter_expression,
375-
dtype=self._dtype,
372+
dtype=self._vectorizer.dtype,
376373
)
377374

378375
# Search the cache!
@@ -543,7 +540,7 @@ def store(
543540
# Load cache entry with TTL
544541
ttl = ttl or self._ttl
545542
keys = self._index.load(
546-
data=[cache_entry.to_dict(self._dtype)],
543+
data=[cache_entry.to_dict(self._vectorizer.dtype)],
547544
ttl=ttl,
548545
id_field=ENTRY_ID_FIELD_NAME,
549546
)
@@ -607,7 +604,7 @@ async def astore(
607604
# Load cache entry with TTL
608605
ttl = ttl or self._ttl
609606
keys = await aindex.load(
610-
data=[cache_entry.to_dict(self._dtype)],
607+
data=[cache_entry.to_dict(self._vectorizer.dtype)],
611608
ttl=ttl,
612609
id_field=ENTRY_ID_FIELD_NAME,
613610
)

redisvl/extensions/router/semantic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def __init__(
7474
"""
7575
# Set vectorizer default
7676
if vectorizer is None:
77-
vectorizer = HFTextVectorizer()
77+
dtype = kwargs.get("dtype")
78+
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
79+
vectorizer = HFTextVectorizer(**vectorizer_kwargs)
7880

7981
if routing_config is None:
8082
routing_config = RoutingConfig()
@@ -85,9 +87,8 @@ def __init__(
8587
vectorizer=vectorizer,
8688
routing_config=routing_config,
8789
)
88-
dtype = kwargs.get("dtype", "float32")
8990
self._initialize_index(
90-
redis_client, redis_url, overwrite, dtype, **connection_kwargs
91+
redis_client, redis_url, overwrite, vectorizer.dtype, **connection_kwargs
9192
)
9293

9394
def _initialize_index(

0 commit comments

Comments
 (0)