Skip to content

Commit 05b1da8

Browse files
authored
Merge pull request #1638 from chenmoneygithub/fix-saving-arg
Fix saving/loading retriever
2 parents a1eae3f + d5e80ef commit 05b1da8

File tree

2 files changed: 36 additions and 9 deletions.

dspy/primitives/module.py

Lines changed: 12 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
# NOTE: It's important (as a temporary decision) to keep named_parameters behaving differently from
99
# named_sub_modules for the time being.
1010

11+
1112
class BaseModule:
1213
def __init__(self):
1314
pass
@@ -29,7 +30,7 @@ def add_parameter(param_name, param_value):
2930
visited.add(id(param_value))
3031
param_name = postprocess_parameter_name(param_name, param_value)
3132
named_parameters.append((param_name, param_value))
32-
33+
3334
elif isinstance(param_value, dspy.Module):
3435
# When a sub-module is pre-compiled, keep it frozen.
3536
if not getattr(param_value, "_compiled", False):
@@ -42,7 +43,7 @@ def add_parameter(param_name, param_value):
4243
for name, value in self.__dict__.items():
4344
if isinstance(value, Parameter):
4445
add_parameter(name, value)
45-
46+
4647
elif isinstance(value, dspy.Module):
4748
# When a sub-module is pre-compiled, keep it frozen.
4849
if not getattr(value, "_compiled", False):
@@ -153,7 +154,11 @@ def dump_state(self, save_verbose):
153154

154155
def load_state(self, state, use_legacy_loading=False):
155156
for name, param in self.named_parameters():
156-
param.load_state(state[name], use_legacy_loading=use_legacy_loading)
157+
if isinstance(param, BaseModule):
158+
param.load_state(state[name], use_legacy_loading=use_legacy_loading)
159+
else:
160+
# `use_legacy_loading` is only applicable for BaseModule instances.
161+
param.load_state(state[name])
157162

158163
def save(self, path, save_field_meta=False):
159164
with open(path, "w") as f:
@@ -168,11 +173,11 @@ def postprocess_parameter_name(name, value):
168173
# For ChainOfThought backward compatibility, remove ending ._predict if it's there
169174
if name.endswith("._predict"):
170175
name = name[:-9]
171-
176+
172177
if name.endswith(".self"):
173178
name = name[:-5]
174-
179+
175180
if name == "_predict":
176181
return "self"
177-
178-
return name
182+
183+
return name

tests/retrieve/test_llama_index_rm.py

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
@pytest.fixture()
2222
def rag_setup() -> dict:
2323
"""Builds the necessary fixtures to test LI"""
24-
pytest.importorskip("llamaindex")
24+
pytest.importorskip("llama_index")
2525
dataset = HotPotQA(train_seed=1, train_size=8, eval_seed=2023, dev_size=4, test_size=0)
2626
trainset = [x.with_inputs("question") for x in dataset.train]
2727
devset = [x.with_inputs("question") for x in dataset.dev]
@@ -46,7 +46,7 @@ def rag_setup() -> dict:
4646

4747
def test_lirm_as_rm(rag_setup):
4848
"""Test the retriever as retriever method"""
49-
pytest.importorskip("llamaindex")
49+
pytest.importorskip("llama_index")
5050
retriever = rag_setup.get("retriever")
5151
test_res_li = retriever.retrieve("At My Window was released by which American singer-songwriter?")
5252
rm = rag_setup.get("rm")
@@ -59,3 +59,25 @@ def test_lirm_as_rm(rag_setup):
5959
assert isinstance(test_res_dspy, list), "Ensuring the results are a list from the DSPy retriever"
6060

6161
assert len(test_res_li) == len(test_res_dspy), "Rough equality check of the results"
62+
63+
64+
def test_save_load_llama_index_rag(rag_setup, tmp_path):
65+
pytest.importorskip("llama_index")
66+
67+
class RAG(dspy.Module):
68+
def __init__(self):
69+
super().__init__()
70+
self.retriever = dspy.Retrieve(k=3)
71+
self.cot = dspy.ChainOfThought("question, context -> answer")
72+
73+
rag = RAG()
74+
rag.retriever.k = 4
75+
76+
file_path = tmp_path / "rag.json"
77+
rag.save(file_path)
78+
loaded_rag = RAG()
79+
# Before loading, the retriever k should be 3.
80+
assert loaded_rag.retriever.k == 3
81+
# After loading, the retriever k should be 4.
82+
loaded_rag.load(file_path)
83+
assert loaded_rag.retriever.k == 4

0 commit comments

Comments
 (0)