Skip to content

Commit e3b5965

Browse files
make extension classes accept vectorizer kwargs
1 parent ab0142c commit e3b5965

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,8 @@ def __init__(
9595
}
9696

9797
# Use the index name as the key prefix by default
98-
if "prefix" in kwargs:
99-
prefix = kwargs["prefix"]
100-
else:
101-
prefix = name
102-
103-
dtype = kwargs.get("dtype")
98+
prefix = kwargs.pop("prefix", name)
99+
dtype = kwargs.pop("dtype", None)
104100

105101
# Validate a provided vectorizer or set the default
106102
if vectorizer:
@@ -111,7 +107,10 @@ def __init__(
111107
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
112108
)
113109
else:
114-
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
110+
vectorizer_kwargs = kwargs
111+
112+
if dtype:
113+
vectorizer_kwargs.update(**{"dtype": dtype})
115114

116115
vectorizer = HFTextVectorizer(
117116
model="sentence-transformers/all-mpnet-base-v2",

redisvl/extensions/router/semantic.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def __init__(
7272
connection_kwargs (Dict[str, Any]): The connection arguments
7373
for the redis client. Defaults to empty {}.
7474
"""
75-
dtype = kwargs.get("dtype")
75+
dtype = kwargs.pop("dtype", None)
7676

7777
# Validate a provided vectorizer or set the default
7878
if vectorizer:
@@ -83,8 +83,15 @@ def __init__(
8383
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
8484
)
8585
else:
86-
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
87-
vectorizer = HFTextVectorizer(**vectorizer_kwargs)
86+
vectorizer_kwargs = kwargs
87+
88+
if dtype:
89+
vectorizer_kwargs.update(**{"dtype": dtype})
90+
91+
vectorizer = HFTextVectorizer(
92+
model="sentence-transformers/all-mpnet-base-v2",
93+
**vectorizer_kwargs,
94+
)
8895

8996
if routing_config is None:
9097
routing_config = RoutingConfig()

redisvl/extensions/session_manager/semantic_session.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def __init__(
7171
super().__init__(name, session_tag)
7272

7373
prefix = prefix or name
74-
dtype = kwargs.get("dtype")
74+
dtype = kwargs.pop("dtype", None)
7575

7676
# Validate a provided vectorizer or set the default
7777
if vectorizer:
@@ -82,10 +82,13 @@ def __init__(
8282
f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}"
8383
)
8484
else:
85-
vectorizer_kwargs = {"dtype": dtype} if dtype else {}
85+
vectorizer_kwargs = kwargs
86+
87+
if dtype:
88+
vectorizer_kwargs.update(**{"dtype": dtype})
8689

8790
vectorizer = HFTextVectorizer(
88-
model="sentence-transformers/msmarco-distilbert-cos-v5",
91+
model="sentence-transformers/all-mpnet-base-v2",
8992
**vectorizer_kwargs,
9093
)
9194

0 commit comments

Comments
 (0)