Skip to content

Commit eea6266

Browse files
committed
First pass at central HF fixtures
1 parent fb8ab19 commit eea6266

File tree

9 files changed

+73
-26
lines changed

9 files changed

+73
-26
lines changed

redisvl/extensions/cache/llm/schema.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@
1111
)
1212
from redisvl.redis.utils import array_to_buffer
1313
from redisvl.schema import IndexSchema
14-
from redisvl.utils.utils import ( # hashify is from utils.utils, not redis.utils
15-
current_timestamp,
16-
deserialize,
17-
hashify,
18-
serialize,
19-
)
14+
from redisvl.utils.utils import current_timestamp, deserialize, hashify, serialize
2015

2116

2217
class CacheEntry(BaseModel):

tests/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,16 @@ def hf_vectorizer():
8080
)
8181

8282

83+
@pytest.fixture(scope="session")
84+
def hf_vectorizer_float16():
85+
return HFTextVectorizer(dtype="float16")
86+
87+
88+
@pytest.fixture(scope="session")
89+
def hf_vectorizer_with_model():
90+
return HFTextVectorizer("sentence-transformers/all-mpnet-base-v2")
91+
92+
8393
@pytest.fixture
8494
def sample_datetimes():
8595
return {

tests/integration/test_aggregation.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010

1111

1212
@pytest.fixture
13-
def index(sample_data, redis_url):
13+
def index(sample_data, redis_url, request):
14+
# In xdist, the config has "workerid" in workerinput
15+
workerinput = getattr(request.config, "workerinput", {})
16+
worker_id = workerinput.get("workerid", "master")
17+
1418
index = SearchIndex.from_dict(
1519
{
1620
"index": {
1721
"name": "user_index",
18-
"prefix": "v1",
22+
"prefix": f"v1_{worker_id}",
1923
"storage_type": "hash",
2024
},
2125
"fields": [

tests/integration/test_async_search_index.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,30 @@ def async_index(index_schema, async_client):
3232

3333

3434
@pytest.fixture
35-
def async_index_from_dict():
36-
return AsyncSearchIndex.from_dict({"index": {"name": "my_index"}, "fields": fields})
35+
def async_index_from_dict(request):
36+
# In xdist, the config has "workerid" in workerinput
37+
workerinput = getattr(request.config, "workerinput", {})
38+
worker_id = workerinput.get("workerid", "master")
39+
40+
return AsyncSearchIndex.from_dict(
41+
{"index": {"name": "my_index", "prefix": f"rvl_{worker_id}"}, "fields": fields}
42+
)
3743

3844

3945
@pytest.fixture
40-
def async_index_from_yaml():
41-
return AsyncSearchIndex.from_yaml("schemas/test_json_schema.yaml")
46+
def async_index_from_yaml(request):
47+
# In xdist, the config has "workerid" in workerinput
48+
workerinput = getattr(request.config, "workerinput", {})
49+
worker_id = workerinput.get("workerid", "master")
50+
51+
# Load the schema from YAML
52+
schema = IndexSchema.from_yaml("schemas/test_json_schema.yaml")
53+
54+
# Modify the prefix to include the worker ID
55+
schema.index.prefix = f"{schema.index.prefix}_{worker_id}"
56+
57+
# Create the AsyncSearchIndex with the modified schema
58+
return AsyncSearchIndex(schema=schema)
4259

4360

4461
def test_search_index_properties(index_schema, async_index):

tests/integration/test_llmcache.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616

1717

1818
@pytest.fixture
19-
def vectorizer():
20-
return HFTextVectorizer("sentence-transformers/all-mpnet-base-v2")
19+
def vectorizer(hf_vectorizer_with_model):
20+
return hf_vectorizer_with_model
2121

2222

2323
@pytest.fixture
@@ -929,12 +929,12 @@ def test_bad_dtype_connecting_to_existing_cache(redis_url):
929929
)
930930

931931

932-
def test_vectorizer_dtype_mismatch(redis_url):
932+
def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16):
933933
with pytest.raises(ValueError):
934934
SemanticCache(
935935
name="test_dtype_mismatch",
936936
dtype="float32",
937-
vectorizer=HFTextVectorizer(dtype="float16"),
937+
vectorizer=hf_vectorizer_float16,
938938
redis_url=redis_url,
939939
overwrite=True,
940940
)

tests/integration/test_search_index.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,30 @@ def index(index_schema, client):
4242

4343

4444
@pytest.fixture
45-
def index_from_dict():
46-
return SearchIndex.from_dict({"index": {"name": "my_index"}, "fields": fields})
45+
def index_from_dict(request):
46+
# In xdist, the config has "workerid" in workerinput
47+
workerinput = getattr(request.config, "workerinput", {})
48+
worker_id = workerinput.get("workerid", "master")
49+
50+
return SearchIndex.from_dict(
51+
{"index": {"name": "my_index", "prefix": f"rvl_{worker_id}"}, "fields": fields}
52+
)
4753

4854

4955
@pytest.fixture
50-
def index_from_yaml():
51-
return SearchIndex.from_yaml("schemas/test_json_schema.yaml")
56+
def index_from_yaml(request):
57+
# In xdist, the config has "workerid" in workerinput
58+
workerinput = getattr(request.config, "workerinput", {})
59+
worker_id = workerinput.get("workerid", "master")
60+
61+
# Load the schema from YAML
62+
schema = IndexSchema.from_yaml("schemas/test_json_schema.yaml")
63+
64+
# Modify the prefix to include the worker ID
65+
schema.index.prefix = f"{schema.index.prefix}_{worker_id}"
66+
67+
# Create the SearchIndex with the modified schema
68+
return SearchIndex(schema=schema)
5269

5370

5471
def test_search_index_properties(index_schema, index):

tests/integration/test_search_results.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@ def filter_query():
1616

1717

1818
@pytest.fixture
19-
def index(sample_data, redis_url):
19+
def index(sample_data, redis_url, request):
20+
# In xdist, the config has "workerid" in workerinput
21+
workerinput = getattr(request.config, "workerinput", {})
22+
worker_id = workerinput.get("workerid", "master")
23+
2024
fields_spec = [
2125
{"name": "credit_score", "type": "tag"},
2226
{"name": "user", "type": "tag"},
@@ -37,7 +41,7 @@ def index(sample_data, redis_url):
3741
json_schema = {
3842
"index": {
3943
"name": "user_index_json",
40-
"prefix": "users_json",
44+
"prefix": f"users_json_{worker_id}",
4145
"storage_type": "json",
4246
},
4347
"fields": fields_spec,

tests/integration/test_semantic_router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,13 +320,13 @@ def test_bad_dtype_connecting_to_exiting_router(redis_url, routes):
320320
)
321321

322322

323-
def test_vectorizer_dtype_mismatch(routes, redis_url):
323+
def test_vectorizer_dtype_mismatch(routes, redis_url, hf_vectorizer_float16):
324324
with pytest.raises(ValueError):
325325
SemanticRouter(
326326
name="test_dtype_mismatch",
327327
routes=routes,
328328
dtype="float32",
329-
vectorizer=HFTextVectorizer(dtype="float16"),
329+
vectorizer=hf_vectorizer_float16,
330330
redis_url=redis_url,
331331
overwrite=True,
332332
)

tests/integration/test_session_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -594,12 +594,12 @@ def test_bad_dtype_connecting_to_exiting_session(redis_url):
594594
)
595595

596596

597-
def test_vectorizer_dtype_mismatch(redis_url):
597+
def test_vectorizer_dtype_mismatch(redis_url, hf_vectorizer_float16):
598598
with pytest.raises(ValueError):
599599
SemanticSessionManager(
600600
name="test_dtype_mismatch",
601601
dtype="float32",
602-
vectorizer=HFTextVectorizer(dtype="float16"),
602+
vectorizer=hf_vectorizer_float16,
603603
redis_url=redis_url,
604604
overwrite=True,
605605
)

0 commit comments

Comments
 (0)