Skip to content

Commit 269d9e8

Browse files
committed
fix(dependencies): updated LI tests to be optional
1 parent 216546b commit 269d9e8

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

tests/retrieve/test_llama_index_rm.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
1+
import logging
2+
13
import pytest
2-
from llama_index.core import Settings, VectorStoreIndex
3-
from llama_index.core.base.base_retriever import BaseRetriever
4-
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
5-
from llama_index.core.readers.string_iterable import StringIterableReader
64

75
import dspy
86
from dsp.modules.dummy_lm import DummyLM
97
from dspy.datasets import HotPotQA
10-
from dspy.retrieve.llama_index_rm import LlamaIndexRM
8+
9+
try:
10+
from llama_index.core import Settings, VectorStoreIndex
11+
from llama_index.core.base.base_retriever import BaseRetriever
12+
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
13+
from llama_index.core.readers.string_iterable import StringIterableReader
14+
15+
from dspy.retrieve.llama_index_rm import LlamaIndexRM
16+
17+
except ImportError:
18+
logging.info("Optional dependency llama-index is not installed - skipping LlamaIndexRM tests.")
1119

1220

1321
@pytest.fixture()
1422
def rag_setup() -> dict:
1523
"""Builds the necessary fixtures to test LI"""
24+
pytest.importorskip("llamaindex")
1625
dataset = HotPotQA(train_seed=1, train_size=8, eval_seed=2023, dev_size=4, test_size=0)
1726
trainset = [x.with_inputs("question") for x in dataset.train]
1827
devset = [x.with_inputs("question") for x in dataset.dev]
@@ -37,7 +46,7 @@ def rag_setup() -> dict:
3746

3847
def test_lirm_as_rm(rag_setup):
3948
"""Test the retriever as retriever method"""
40-
49+
pytest.importorskip("llamaindex")
4150
retriever = rag_setup.get("retriever")
4251
test_res_li = retriever.retrieve("At My Window was released by which American singer-songwriter?")
4352
rm = rag_setup.get("rm")

0 commit comments

Comments
 (0)