|
32 | 32 |
|
33 | 33 |
|
34 | 34 | def _deepsearch_retrieve_titles( |
35 | | - question: str, retry_num: int = 4, base_wait_time: int = 4 |
| 35 | + question: str, |
| 36 | + retry_num: int = 4, |
| 37 | + base_wait_time: int = 4, |
| 38 | + max_iter: int = 3, |
36 | 39 | ) -> Tuple[List[str], int, bool]: |
37 | 40 | retrieved_results = [] |
38 | 41 | consume_tokens = 0 |
39 | 42 | for i in range(retry_num): |
40 | 43 | try: |
41 | | - retrieved_results, _, consume_tokens = retrieve(question) |
| 44 | + retrieved_results, _, consume_tokens = retrieve(question, max_iter=max_iter) |
42 | 45 | break |
43 | 46 | except Exception: |
44 | 47 | wait_time = base_wait_time * (2**i) |
@@ -91,6 +94,7 @@ def evaluate( |
91 | 94 | dataset: str, |
92 | 95 | output_root: str, |
93 | 96 | pre_num: int = 10, |
| 97 | + max_iter: int = 3, |
94 | 98 | skip_load=False, |
95 | 99 | flag: str = "result", |
96 | 100 | ): |
@@ -134,7 +138,9 @@ def evaluate( |
134 | 138 | global_idx = sample_idx + start_ind |
135 | 139 | question = sample["question"] |
136 | 140 |
|
137 | | - retrieved_titles, consume_tokens, fail = _deepsearch_retrieve_titles(question) |
| 141 | + retrieved_titles, consume_tokens, fail = _deepsearch_retrieve_titles( |
| 142 | + question, max_iter=max_iter |
| 143 | + ) |
138 | 144 | retrieved_titles_naive = _naive_retrieve_titles(question) |
139 | 145 |
|
140 | 146 | if fail: |
@@ -206,6 +212,12 @@ def main_eval(): |
206 | 212 | default=30, |
207 | 213 | help="Number of samples to evaluate, default is 30", |
208 | 214 | ) |
| 215 | + parser.add_argument( |
| 216 | + "--max_iter", |
| 217 | + type=int, |
| 218 | + default=3, |
| 219 | + help="Max iterations of reflection. Default is 3. It will overwrite the one in config yaml file.", |
| 220 | + ) |
209 | 221 | parser.add_argument( |
210 | 222 | "--output_dir", |
211 | 223 | type=str, |
@@ -233,6 +245,7 @@ def main_eval(): |
233 | 245 | dataset=args.dataset, |
234 | 246 | output_root=args.output_dir, |
235 | 247 | pre_num=args.pre_num, |
| 248 | + max_iter=args.max_iter, |
236 | 249 | skip_load=args.skip_load, |
237 | 250 | flag=args.flag, |
238 | 251 | ) |
|
0 commit comments