Skip to content

Commit 1beb182

Browse files
committed
add max_iter arg in evaluate script
Signed-off-by: ChengZi <chen.zhang@zilliz.com>
1 parent ebf45d1 commit 1beb182

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

evaluation/evaluate.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@
3232

3333

3434
def _deepsearch_retrieve_titles(
35-
question: str, retry_num: int = 4, base_wait_time: int = 4
35+
question: str,
36+
retry_num: int = 4,
37+
base_wait_time: int = 4,
38+
max_iter: int = 3,
3639
) -> Tuple[List[str], int, bool]:
3740
retrieved_results = []
3841
consume_tokens = 0
3942
for i in range(retry_num):
4043
try:
41-
retrieved_results, _, consume_tokens = retrieve(question)
44+
retrieved_results, _, consume_tokens = retrieve(question, max_iter=max_iter)
4245
break
4346
except Exception:
4447
wait_time = base_wait_time * (2**i)
@@ -91,6 +94,7 @@ def evaluate(
9194
dataset: str,
9295
output_root: str,
9396
pre_num: int = 10,
97+
max_iter: int = 3,
9498
skip_load=False,
9599
flag: str = "result",
96100
):
@@ -134,7 +138,9 @@ def evaluate(
134138
global_idx = sample_idx + start_ind
135139
question = sample["question"]
136140

137-
retrieved_titles, consume_tokens, fail = _deepsearch_retrieve_titles(question)
141+
retrieved_titles, consume_tokens, fail = _deepsearch_retrieve_titles(
142+
question, max_iter=max_iter
143+
)
138144
retrieved_titles_naive = _naive_retrieve_titles(question)
139145

140146
if fail:
@@ -206,6 +212,12 @@ def main_eval():
206212
default=30,
207213
help="Number of samples to evaluate, default is 30",
208214
)
215+
parser.add_argument(
216+
"--max_iter",
217+
type=int,
218+
default=3,
219+
help="Max iterations of reflection. Default is 3. It will overwrite the one in config yaml file.",
220+
)
209221
parser.add_argument(
210222
"--output_dir",
211223
type=str,
@@ -233,6 +245,7 @@ def main_eval():
233245
dataset=args.dataset,
234246
output_root=args.output_dir,
235247
pre_num=args.pre_num,
248+
max_iter=args.max_iter,
236249
skip_load=args.skip_load,
237250
flag=args.flag,
238251
)

0 commit comments

Comments
 (0)