Skip to content

Commit 293d7ac

Browse files
committed
Centralize HF vectorizers into a fixture
1 parent 184f521 commit 293d7ac

File tree

9 files changed

72 additions and 21 deletions (+72 −21 lines changed)

9 files changed

72 additions and 21 deletions (+72 −21 lines changed)

redisvl/extensions/cache/llm/schema.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,6 @@ class SemanticCacheIndexSchema(IndexSchema):
114114

115115
@classmethod
116116
def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str):
117-
118117
return cls(
119118
index={"name": name, "prefix": prefix}, # type: ignore
120119
fields=[ # type: ignore

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,16 @@ def hf_vectorizer():
8080
)
8181

8282

83+
@pytest.fixture(scope="session")
84+
def hf_vectorizer_float16():
85+
return HFTextVectorizer(dtype="float16")
86+
87+
88+
@pytest.fixture(scope="session")
89+
def hf_vectorizer_with_model():
90+
return HFTextVectorizer("sentence-transformers/all-mpnet-base-v2")
91+
92+
8393
@pytest.fixture
8494
def sample_datetimes():
8595
return {

tests/integration/test_aggregation.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -10,12 +10,16 @@
1010

1111

1212
@pytest.fixture
13-
def index(sample_data, redis_url):
13+
def index(sample_data, redis_url, request):
14+
# In xdist, the config has "workerid" in workerinput
15+
workerinput = getattr(request.config, "workerinput", {})
16+
worker_id = workerinput.get("workerid", "master")
17+
1418
index = SearchIndex.from_dict(
1519
{
1620
"index": {
1721
"name": "user_index",
18-
"prefix": "v1",
22+
"prefix": f"v1_{worker_id}",
1923
"storage_type": "hash",
2024
},
2125
"fields": [

tests/integration/test_async_search_index.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -32,13 +32,30 @@ def async_index(index_schema, async_client):
3232

3333

3434
@pytest.fixture
35-
def async_index_from_dict():
36-
return AsyncSearchIndex.from_dict({"index": {"name": "my_index"}, "fields": fields})
35+
def async_index_from_dict(request):
36+
# In xdist, the config has "workerid" in workerinput
37+
workerinput = getattr(request.config, "workerinput", {})
38+
worker_id = workerinput.get("workerid", "master")
39+
40+
return AsyncSearchIndex.from_dict(
41+
{"index": {"name": "my_index", "prefix": f"rvl_{worker_id}"}, "fields": fields}
42+
)
3743

3844

3945
@pytest.fixture
40-
def async_index_from_yaml():
41-
return AsyncSearchIndex.from_yaml("schemas/test_json_schema.yaml")
46+
def async_index_from_yaml(request):
47+
# In xdist, the config has "workerid" in workerinput
48+
workerinput = getattr(request.config, "workerinput", {})
49+
worker_id = workerinput.get("workerid", "master")
50+
51+
# Load the schema from YAML
52+
schema = IndexSchema.from_yaml("schemas/test_json_schema.yaml")
53+
54+
# Modify the prefix to include the worker ID
55+
schema.index.prefix = f"{schema.index.prefix}_{worker_id}"
56+
57+
# Create the AsyncSearchIndex with the modified schema
58+
return AsyncSearchIndex(schema=schema)
4259

4360

4461
def test_search_index_properties(index_schema, async_index):

tests/integration/test_llmcache.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@
1616

1717

1818
@pytest.fixture
19-
def vectorizer():
20-
return HFTextVectorizer("sentence-transformers/all-mpnet-base-v2")
19+
def vectorizer(hf_vectorizer_with_model):
20+
return hf_vectorizer_with_model
2121

2222

2323
@pytest.fixture
@@ -929,12 +929,12 @@ def test_bad_dtype_connecting_to_existing_cache(redis_url):
929929
)
930930

931931

932-
def test_vectorizer_dtype_mismatch(redis_url):
932+
def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16):
933933
with pytest.raises(ValueError):
934934
SemanticCache(
935935
name="test_dtype_mismatch",
936936
dtype="float32",
937-
vectorizer=HFTextVectorizer(dtype="float16"),
937+
vectorizer=hf_vectorizer_float16,
938938
redis_url=redis_url,
939939
overwrite=True,
940940
)

tests/integration/test_message_history.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -591,12 +591,12 @@ def test_bad_dtype_connecting_to_exiting_history(redis_url):
591591
)
592592

593593

594-
def test_vectorizer_dtype_mismatch(redis_url):
594+
def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16):
595595
with pytest.raises(ValueError):
596596
SemanticMessageHistory(
597597
name="test_dtype_mismatch",
598598
dtype="float32",
599-
vectorizer=HFTextVectorizer(dtype="float16"),
599+
vectorizer=hf_vectorizer_float16,
600600
redis_url=redis_url,
601601
overwrite=True,
602602
)

tests/integration/test_search_index.py

Lines changed: 21 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -42,13 +42,30 @@ def index(index_schema, client):
4242

4343

4444
@pytest.fixture
45-
def index_from_dict():
46-
return SearchIndex.from_dict({"index": {"name": "my_index"}, "fields": fields})
45+
def index_from_dict(request):
46+
# In xdist, the config has "workerid" in workerinput
47+
workerinput = getattr(request.config, "workerinput", {})
48+
worker_id = workerinput.get("workerid", "master")
49+
50+
return SearchIndex.from_dict(
51+
{"index": {"name": "my_index", "prefix": f"rvl_{worker_id}"}, "fields": fields}
52+
)
4753

4854

4955
@pytest.fixture
50-
def index_from_yaml():
51-
return SearchIndex.from_yaml("schemas/test_json_schema.yaml")
56+
def index_from_yaml(request):
57+
# In xdist, the config has "workerid" in workerinput
58+
workerinput = getattr(request.config, "workerinput", {})
59+
worker_id = workerinput.get("workerid", "master")
60+
61+
# Load the schema from YAML
62+
schema = IndexSchema.from_yaml("schemas/test_json_schema.yaml")
63+
64+
# Modify the prefix to include the worker ID
65+
schema.index.prefix = f"{schema.index.prefix}_{worker_id}"
66+
67+
# Create the SearchIndex with the modified schema
68+
return SearchIndex(schema=schema)
5269

5370

5471
def test_search_index_properties(index_schema, index):

tests/integration/test_search_results.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,11 @@ def filter_query():
1616

1717

1818
@pytest.fixture
19-
def index(sample_data, redis_url):
19+
def index(sample_data, redis_url, request):
20+
# In xdist, the config has "workerid" in workerinput
21+
workerinput = getattr(request.config, "workerinput", {})
22+
worker_id = workerinput.get("workerid", "master")
23+
2024
fields_spec = [
2125
{"name": "credit_score", "type": "tag"},
2226
{"name": "user", "type": "tag"},
@@ -37,7 +41,7 @@ def index(sample_data, redis_url):
3741
json_schema = {
3842
"index": {
3943
"name": "user_index_json",
40-
"prefix": "users_json",
44+
"prefix": f"users_json_{worker_id}",
4145
"storage_type": "json",
4246
},
4347
"fields": fields_spec,

tests/integration/test_semantic_router.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -325,13 +325,13 @@ def test_bad_dtype_connecting_to_exiting_router(redis_url, routes):
325325
)
326326

327327

328-
def test_vectorizer_dtype_mismatch(routes, redis_url):
328+
def test_vectorizer_dtype_mismatch(routes, redis_url, hf_vectorizer_float16):
329329
with pytest.raises(ValueError):
330330
SemanticRouter(
331331
name="test_dtype_mismatch",
332332
routes=routes,
333333
dtype="float32",
334-
vectorizer=HFTextVectorizer(dtype="float16"),
334+
vectorizer=hf_vectorizer_float16,
335335
redis_url=redis_url,
336336
overwrite=True,
337337
)

0 commit comments

Comments
 (0)