Skip to content

Commit 50e13a7

Browse files
clean up makefile and scripts
1 parent 44f4ba8 commit 50e13a7

File tree

6 files changed

+66
-144
lines changed

6 files changed

+66
-144
lines changed

.github/workflows/test.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ jobs:
8383
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
8484
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
8585
run: |
86-
poetry run test-verbose
86+
make test-all
8787
8888
- name: Run tests
8989
if: matrix.connection != 'plain' || matrix.redis-stack-version != 'latest'
9090
run: |
91-
SKIP_VECTORIZERS=True SKIP_RERANKERS=True poetry run test-verbose
91+
make test
9292
9393
- name: Run notebooks
9494
if: matrix.connection == 'plain' && matrix.redis-stack-version == 'latest'
@@ -106,7 +106,7 @@ jobs:
106106
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
107107
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
108108
run: |
109-
cd docs/ && poetry run pytest --nbval-lax ./user_guide -vv
109+
make test-notebooks
110110
111111
docs:
112112
runs-on: ubuntu-latest

Makefile

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
.PHONY: install format lint test test-all clean redis-start redis-stop check-types docs-build docs-serve check
1+
.PHONY: install format lint test test-all test-notebooks clean redis-start redis-stop check-types docs-build docs-serve check
22

33
install:
44
poetry install --all-extras
@@ -19,10 +19,13 @@ check-types:
1919
lint: format check-types
2020

2121
test:
22-
SKIP_RERANKERS=true SKIP_VECTORIZERS=true poetry run test-verbose
22+
poetry run test-verbose
2323

2424
test-all:
25-
poetry run test-verbose
25+
poetry run test-verbose --run-api-tests
26+
27+
test-notebooks:
28+
poetry run test-notebooks
2629

2730
check: lint test
2831

scripts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_verbose():
4848

4949

5050
def test_notebooks():
51-
subprocess.run(["cd", "docs/", "&&", "poetry run treon", "-v"], check=True)
51+
subprocess.run(["cd", "docs/", "&&", "poetry run pytest --nbval-lax ./user_guide", "-vv"], check=True)
5252

5353

5454
def build_docs():

tests/conftest.py

Lines changed: 2 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,6 @@ async def async_client(redis_url):
5858
redis_url
5959
) as client:
6060
yield client
61-
# try:
62-
# await client.aclose()
63-
# except RuntimeError as e:
64-
# if "Event loop is closed" not in str(e):
65-
# raise
6661

6762

6863
@pytest.fixture
@@ -72,51 +67,6 @@ def client(redis_url):
7267
"""
7368
conn = RedisConnectionFactory.get_redis_connection(redis_url)
7469
yield conn
75-
# conn.close()
76-
77-
78-
# @pytest.fixture
79-
# def openai_key():
80-
# return os.getenv("OPENAI_API_KEY")
81-
82-
83-
# @pytest.fixture
84-
# def openai_version():
85-
# return os.getenv("OPENAI_API_VERSION")
86-
87-
88-
# @pytest.fixture
89-
# def azure_endpoint():
90-
# return os.getenv("AZURE_OPENAI_ENDPOINT")
91-
92-
93-
# @pytest.fixture
94-
# def cohere_key():
95-
# return os.getenv("COHERE_API_KEY")
96-
97-
98-
# @pytest.fixture
99-
# def mistral_key():
100-
# return os.getenv("MISTRAL_API_KEY")
101-
102-
103-
# @pytest.fixture
104-
# def gcp_location():
105-
# return os.getenv("GCP_LOCATION")
106-
107-
108-
# @pytest.fixture
109-
# def gcp_project_id():
110-
# return os.getenv("GCP_PROJECT_ID")
111-
112-
113-
# @pytest.fixture
114-
# def aws_credentials():
115-
# return {
116-
# "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
117-
# "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY"),
118-
# "aws_region": os.getenv("AWS_REGION", "us-east-1"),
119-
# }
12070

12171

12272
@pytest.fixture
@@ -201,6 +151,8 @@ def pytest_collection_modifyitems(
201151
) -> None:
202152
if config.getoption("--run-api-tests"):
203153
return
154+
155+
# Otherwise skip all tests requiring an API key
204156
skip_api = pytest.mark.skip(
205157
reason="Skipping test because API keys are not provided. Use --run-api-tests to run these tests."
206158
)

tests/integration/test_rerankers.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,14 @@
99
)
1010

1111

12-
@pytest.fixture
13-
def skip_reranker() -> bool:
14-
# os.getenv returns a string
15-
v = os.getenv("SKIP_RERANKERS", "False").lower() == "true"
16-
return v
17-
18-
1912
# Fixture for the reranker instance
2013
@pytest.fixture(
2114
params=[
2215
CohereReranker,
2316
VoyageAIReranker,
2417
]
2518
)
26-
def reranker(request, skip_reranker):
27-
if skip_reranker:
28-
pytest.skip("Skipping reranker instantiation...")
29-
19+
def reranker(request):
3020
if request.param == CohereReranker:
3121
return request.param()
3222
elif request.param == VoyageAIReranker:
@@ -43,7 +33,7 @@ def hfCrossEncoderRerankerWithCustomModel():
4333
return HFCrossEncoderReranker("cross-encoder/stsb-distilroberta-base")
4434

4535

46-
# Test for basic ranking functionality
36+
@pytest.mark.requires_api_keys
4737
def test_rank_documents(reranker):
4838
docs = ["document one", "document two", "document three"]
4939
query = "search query"
@@ -55,7 +45,7 @@ def test_rank_documents(reranker):
5545
assert all(isinstance(score, float) for score in scores) # Scores should be floats
5646

5747

58-
# Test for asynchronous ranking functionality
48+
@pytest.mark.requires_api_keys
5949
@pytest.mark.asyncio
6050
async def test_async_rank_documents(reranker):
6151
docs = ["document one", "document two", "document three"]
@@ -68,7 +58,7 @@ async def test_async_rank_documents(reranker):
6858
assert all(isinstance(score, float) for score in scores) # Scores should be floats
6959

7060

71-
# Test handling of bad input
61+
@pytest.mark.requires_api_keys
7262
def test_bad_input(reranker):
7363
with pytest.raises(Exception):
7464
reranker.rank("", []) # Empty query or documents

tests/integration/test_vectorizers.py

Lines changed: 50 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@
1515
)
1616

1717

18-
@pytest.fixture
19-
def skip_vectorizer() -> bool:
20-
v = os.getenv("SKIP_VECTORIZERS", "False").lower() == "true"
21-
return v
22-
23-
2418
@pytest.fixture(
2519
params=[
2620
HFTextVectorizer,
@@ -34,10 +28,7 @@ def skip_vectorizer() -> bool:
3428
VoyageAITextVectorizer,
3529
]
3630
)
37-
def vectorizer(request, skip_vectorizer):
38-
if skip_vectorizer:
39-
pytest.skip("Skipping vectorizer instantiation...")
40-
31+
def vectorizer(request):
4132
if request.param == HFTextVectorizer:
4233
return request.param()
4334
elif request.param == OpenAITextVectorizer:
@@ -70,10 +61,7 @@ def embed_many(texts):
7061

7162

7263
@pytest.fixture
73-
def bedrock_vectorizer(skip_vectorizer):
74-
if skip_vectorizer:
75-
pytest.skip("Skipping Bedrock vectorizer tests...")
76-
64+
def bedrock_vectorizer():
7765
return BedrockTextVectorizer(
7866
model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0")
7967
)
@@ -108,6 +96,7 @@ def embed_many_with_args(self, texts, param=True):
10896
return MyEmbedder
10997

11098

99+
@pytest.mark.requires_api_keys
111100
def test_vectorizer_embed(vectorizer):
112101
text = "This is a test sentence."
113102
if isinstance(vectorizer, CohereTextVectorizer):
@@ -121,6 +110,7 @@ def test_vectorizer_embed(vectorizer):
121110
assert len(embedding) == vectorizer.dims
122111

123112

113+
@pytest.mark.requires_api_keys
124114
def test_vectorizer_embed_many(vectorizer):
125115
texts = ["This is the first test sentence.", "This is the second test sentence."]
126116
if isinstance(vectorizer, CohereTextVectorizer):
@@ -137,6 +127,7 @@ def test_vectorizer_embed_many(vectorizer):
137127
)
138128

139129

130+
@pytest.mark.requires_api_keys
140131
def test_vectorizer_bad_input(vectorizer):
141132
with pytest.raises(TypeError):
142133
vectorizer.embed(1)
@@ -148,6 +139,7 @@ def test_vectorizer_bad_input(vectorizer):
148139
vectorizer.embed_many(42)
149140

150141

142+
@pytest.mark.requires_api_keys
151143
def test_bedrock_bad_credentials():
152144
with pytest.raises(ValueError):
153145
BedrockTextVectorizer(
@@ -158,6 +150,7 @@ def test_bedrock_bad_credentials():
158150
)
159151

160152

153+
@pytest.mark.requires_api_keys
161154
def test_bedrock_invalid_model(bedrock_vectorizer):
162155
with pytest.raises(ValueError):
163156
bedrock = BedrockTextVectorizer(model="invalid-model")
@@ -250,64 +243,48 @@ def bad_return_type(text: str) -> str:
250243
)
251244

252245

253-
@pytest.mark.parametrize(
254-
"vector_class",
255-
[
256-
AzureOpenAITextVectorizer,
257-
BedrockTextVectorizer,
258-
CohereTextVectorizer,
259-
CustomTextVectorizer,
260-
HFTextVectorizer,
261-
MistralAITextVectorizer,
262-
OpenAITextVectorizer,
263-
VertexAITextVectorizer,
264-
VoyageAITextVectorizer,
265-
],
266-
)
267-
def test_dtypes(vector_class, skip_vectorizer):
268-
if skip_vectorizer:
269-
pytest.skip("Skipping vectorizer instantiation...")
270-
271-
# test dtype defaults to float32
272-
if issubclass(vector_class, CustomTextVectorizer):
273-
vectorizer = vector_class(embed=lambda x, input_type=None: [1.0, 2.0, 3.0])
274-
elif issubclass(vector_class, AzureOpenAITextVectorizer):
275-
vectorizer = vector_class(
276-
model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002")
277-
)
278-
else:
279-
vectorizer = vector_class()
280-
281-
assert vectorizer.dtype == "float32"
282-
283-
# test initializing dtype in constructor
284-
for dtype in ["float16", "float32", "float64", "bfloat16"]:
285-
if issubclass(vector_class, CustomTextVectorizer):
286-
vectorizer = vector_class(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype)
287-
elif issubclass(vector_class, AzureOpenAITextVectorizer):
288-
vectorizer = vector_class(
289-
model=os.getenv(
290-
"AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"
291-
),
292-
dtype=dtype,
293-
)
294-
else:
295-
vectorizer = vector_class(dtype=dtype)
296-
297-
assert vectorizer.dtype == dtype
298-
299-
# test validation of dtype on init
300-
if issubclass(vector_class, CustomTextVectorizer):
301-
pytest.skip("skipping custom text vectorizer")
246+
# @pytest.mark.requires_api_keys
247+
# def test_dtypes(vectorizer):
248+
# # # test dtype defaults to float32
249+
# # if issubclass(vectorizer, CustomTextVectorizer):
250+
# # vectorizer = vectorizer(embed=lambda x, input_type=None: [1.0, 2.0, 3.0])
251+
# # elif issubclass(vectorizer, AzureOpenAITextVectorizer):
252+
# # vectorizer = vectorizer(
253+
# # model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002")
254+
# # )
255+
# # else:
256+
# # vectorizer = vector_class()
302257

303-
with pytest.raises(ValueError):
304-
vectorizer = vector_class(dtype="float25")
258+
# assert vectorizer.dtype == "float32"
305259

306-
with pytest.raises(ValueError):
307-
vectorizer = vector_class(dtype=7)
260+
# # test initializing dtype in constructor
261+
# for dtype in ["float16", "float32", "float64", "bfloat16"]:
262+
# if issubclass(vectorizer, CustomTextVectorizer):
263+
# vectorizer = vectorizer(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype)
264+
# elif issubclass(vectorizer, AzureOpenAITextVectorizer):
265+
# vectorizer = vectorizer(
266+
# model=os.getenv(
267+
# "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002"
268+
# ),
269+
# dtype=dtype,
270+
# )
271+
# else:
272+
# vectorizer = vectorizer(dtype=dtype)
308273

309-
with pytest.raises(ValueError):
310-
vectorizer = vector_class(dtype=None)
274+
# assert vectorizer.dtype == dtype
275+
276+
# # test validation of dtype on init
277+
# if issubclass(vectorizer, CustomTextVectorizer):
278+
# pytest.skip("skipping custom text vectorizer")
279+
280+
# with pytest.raises(ValueError):
281+
# vectorizer = vectorizer(dtype="float25")
282+
283+
# with pytest.raises(ValueError):
284+
# vectorizer = vectorizer(dtype=7)
285+
286+
# with pytest.raises(ValueError):
287+
# vectorizer = vectorizer(dtype=None)
311288

312289

313290
@pytest.fixture(
@@ -319,10 +296,7 @@ def test_dtypes(vector_class, skip_vectorizer):
319296
VoyageAITextVectorizer,
320297
]
321298
)
322-
def avectorizer(request, skip_vectorizer):
323-
if skip_vectorizer:
324-
pytest.skip("Skipping vectorizer instantiation...")
325-
299+
def avectorizer(request):
326300
if request.param == CustomTextVectorizer:
327301

328302
def embed_func(text):
@@ -341,6 +315,7 @@ async def aembed_many_func(texts):
341315
return request.param()
342316

343317

318+
@pytest.mark.requires_api_keys
344319
@pytest.mark.asyncio
345320
async def test_vectorizer_aembed(avectorizer):
346321
text = "This is a test sentence."
@@ -350,6 +325,7 @@ async def test_vectorizer_aembed(avectorizer):
350325
assert len(embedding) == avectorizer.dims
351326

352327

328+
@pytest.mark.requires_api_keys
353329
@pytest.mark.asyncio
354330
async def test_vectorizer_aembed_many(avectorizer):
355331
texts = ["This is the first test sentence.", "This is the second test sentence."]
@@ -362,6 +338,7 @@ async def test_vectorizer_aembed_many(avectorizer):
362338
)
363339

364340

341+
@pytest.mark.requires_api_keys
365342
@pytest.mark.asyncio
366343
async def test_avectorizer_bad_input(avectorizer):
367344
with pytest.raises(TypeError):

0 commit comments

Comments
 (0)