-
Notifications
You must be signed in to change notification settings - Fork 46
Expand file tree
/
Copy pathmain_mcts.py
More file actions
executable file
·214 lines (171 loc) · 7.23 KB
/
main_mcts.py
File metadata and controls
executable file
·214 lines (171 loc) · 7.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import atexit
import logging
import shutil
import sys
import backend
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed, wait, FIRST_COMPLETED
from agent.mcts_agent import MCTSAgent as Agent
from interpreter.interpreter_parallel import Interpreter
from search.journal import Journal
from search.node import Node
from omegaconf import OmegaConf
from rich.columns import Columns
from rich.console import Group
from rich.padding import Padding
from rich.panel import Panel
from rich.progress import (
BarColumn,
MofNCompleteColumn,
Progress,
TextColumn,
TimeRemainingColumn,
)
from rich.text import Text
from rich.markdown import Markdown
from rich.status import Status
from rich.tree import Tree
from utils.config_mcts import load_task_desc, prep_agent_workspace, save_run, load_cfg
class VerboseFilter(logging.Filter):
    """
    Logging filter that rejects records flagged as verbose.

    A record is rejected when it carries a truthy ``verbose`` attribute;
    records without the attribute (or with a falsy value) pass through.
    """

    def filter(self, record):
        # A missing attribute is treated the same as ``verbose=False``.
        is_verbose = bool(getattr(record, "verbose", False))
        return not is_verbose
def journal_to_rich_tree(journal: Journal):
    """
    Build a rich ``Tree`` mirroring the journal's node hierarchy.

    Buggy nodes are labelled in red as ``◍ bug``; every other node is
    labelled in green with its metric value (three decimals). The node
    returned by ``journal.get_best_node()`` is rendered bold and tagged
    ``(best)``. Returns the populated tree.
    """
    top_node = journal.get_best_node()

    def _attach(current: Node, parent_branch):
        if current.is_buggy:
            label = "[red]◍ bug"
        else:
            is_top = current is top_node
            emphasis = "bold " if is_top else ""
            suffix = " (best)" if is_top else ""
            label = f"[{emphasis}green]● {current.metric.value:.3f}{suffix}"
        branch = parent_branch.add(label)
        # Recurse so that each child hangs off the branch just created.
        for kid in current.children:
            _attach(kid, branch)

    root = Tree("[bold blue]Solution tree")
    for draft in journal.draft_nodes:
        _attach(draft, root)
    return root
def journal_to_string_tree(journal: "Journal") -> str:
    """
    Render the journal's solution tree as plain, indented text.

    The first line is the header ``Solution tree``. Every node reachable from
    ``journal.draft_nodes`` follows on its own line, indented by one space per
    depth level. Buggy nodes are shown as ``◍ bug``; all other nodes as ``●``
    followed by their metric value with three decimals. The node returned by
    ``journal.get_best_node()`` is additionally tagged ``(best)``. Each node
    line ends with the node's ID.

    Args:
        journal: Journal exposing ``draft_nodes`` and ``get_best_node()``.

    Returns:
        The multi-line tree text; every line is newline-terminated.
    """
    best_node = journal.get_best_node()
    # Collect the lines and join once instead of growing a string with ``+=``.
    lines = ["Solution tree\n"]

    def append_rec(node: "Node", level: int) -> None:
        indent = " " * level
        if node.is_buggy:
            lines.append(f"{indent}◍ bug (ID: {node.id})\n")
        else:
            # support for multiple markers; atm only "best" is supported
            markers = []
            if node is best_node:
                markers.append("best")
            # Decide on the markers alone: also testing the metric value's
            # truthiness would silently drop "(best)" for a score of 0.0.
            marker_str = f" ({' & '.join(markers)})" if markers else ""
            lines.append(
                f"{indent}● {node.metric.value:.3f}{marker_str} (ID: {node.id})\n"
            )
        for child in node.children:
            append_rec(child, level + 1)

    for n in journal.draft_nodes:
        append_rec(n, 0)
    return "".join(lines)
def run():
    """
    Run one search experiment from configuration to final journal.

    Steps, in order:
      1. Load the config and set up logging: a console handler and a regular
         log file (both skip records flagged ``verbose``) plus a verbose log
         file that receives everything.
      2. Load the task description and prepare the agent workspace.
      3. Validate the steerable-reasoning / check-format options.
      4. Build the journal, agent and interpreter, then repeatedly submit
         ``agent.step`` to a thread pool until ``cfg.agent.steps`` nodes have
         been recorded, saving the run after every finished step.

    Raises:
        ValueError: if steerable reasoning is enabled while the configured
            code model name contains ``gpt``, ``gemini`` or ``claude``.
    """
    cfg = load_cfg()
    log_format = "[%(asctime)s] %(levelname)s: %(message)s"
    logging.basicConfig(
        level=getattr(logging, cfg.log_level.upper()), format=log_format, handlers=[]
    )
    # dont want info logs from httpx
    httpx_logger: logging.Logger = logging.getLogger("httpx")
    httpx_logger.setLevel(logging.WARNING)

    logger = logging.getLogger("ml-master")
    # save logs to files as well, using same format
    cfg.log_dir.mkdir(parents=True, exist_ok=True)

    # we'll have a normal log file and verbose log file. Only normal to console
    file_handler = logging.FileHandler(cfg.log_dir / "ml-master.log")
    file_handler.setFormatter(logging.Formatter(log_format))
    file_handler.addFilter(VerboseFilter())

    verbose_file_handler = logging.FileHandler(cfg.log_dir / "ml-master.verbose.log")
    verbose_file_handler.setFormatter(logging.Formatter(log_format))

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(logging.Formatter(log_format))
    console_handler.addFilter(VerboseFilter())

    logger.addHandler(file_handler)
    logger.addHandler(verbose_file_handler)
    logger.addHandler(console_handler)

    logger.info(f'Starting run "{cfg.exp_name}"')

    task_desc = load_task_desc(cfg)
    task_desc_str = backend.compile_prompt_to_md(task_desc)

    with Status("Preparing agent workspace (copying and extracting files) ..."):
        prep_agent_workspace(cfg)

    # Number of recorded steps (virtual node excluded). It is initialised
    # before the exit hook is registered so the hook can never hit an
    # unassigned name, and it is refreshed after every finished step so the
    # hook sees the real progress rather than a stale start-up snapshot.
    completed = 0

    def cleanup():
        # Only discard the workspace when the run recorded no step at all.
        if completed <= 0:
            shutil.rmtree(cfg.workspace_dir)

    if cfg.agent.steerable_reasoning == True:
        logger.warning("Steerable reasoning is enabled, please make sure your open sourced model api support `client.compeletion.create()`, otherwise the process may fail")
        if any(vendor in cfg.agent.code.model for vendor in ("gpt", "gemini", "claude")):
            logger.warning("Steerable reasoning does not support close sourced models, please set steerable reasoning to false")
            raise ValueError("Steerable reasoning does not support close sourced models, please set steerable reasoning to false")

    if cfg.agent.check_format == True:
        logger.warning("Check format is enabled, please make sure you have launched the server, or this step will be skipped")

    atexit.register(cleanup)

    journal = Journal()
    agent = Agent(
        task_desc=task_desc,
        cfg=cfg,
        journal=journal,
    )
    interpreter = Interpreter(
        cfg.workspace_dir, **OmegaConf.to_container(cfg.exec), cfg=cfg  # type: ignore
    )

    global_step = len(journal)
    prog = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(bar_width=20),
        MofNCompleteColumn(),
        TimeRemainingColumn(),
    )
    status = Status("[green]Generating code...")
    prog.add_task("Progress:", total=cfg.agent.steps, completed=global_step)

    def exec_callback(*args, **kwargs):
        # Flip the status text around the interpreter call.
        status.update("[magenta]Executing code...")
        res = interpreter.run(*args, **kwargs)
        status.update("[green]Generating code...")
        return res

    def step_task(node=None):
        if node:
            logger.info(f"[step_task] Processing node: {node.id}")
        else:
            logger.info("[step_task] Processing virtual root node.")
        return agent.step(exec_callback=exec_callback, node=node)

    max_workers = cfg.agent.search.parallel_search_num
    total_steps = cfg.agent.steps

    try:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = {
                executor.submit(step_task)
                for _ in range(min(max_workers, total_steps))
            }
            # NOTE(review): only this thread ever takes the lock, so it does
            # not guard against workers touching the journal concurrently;
            # confirm whether agent.step is expected to share it.
            lock = threading.Lock()

            # Stop as soon as the target is reached. ``wait`` returns
            # immediately for an empty set, so looping while
            # ``completed == total_steps`` (or with nothing in flight) would
            # spin forever without making progress.
            while completed < total_steps and futures:
                done, _ = wait(futures, return_when=FIRST_COMPLETED)
                for fut in done:
                    futures.remove(fut)
                    try:
                        cur_node = fut.result()
                        logger.info(f"current node count is {completed}, current node.id is {cur_node.id}")
                    except Exception as e:
                        logger.exception(f"Exception during step_task execution: {e}")
                        # Fall back to the virtual root for the next step.
                        cur_node = None

                    with lock:
                        save_run(cfg, journal)
                        # Integer arithmetic: the counter is compared with and
                        # logged next to integer step counts.
                        completed = len(journal) - 1  # Exclude virtual node
                        if completed == total_steps:
                            logger.info(journal_to_string_tree(journal))

                    # Keep finished + in-flight work within the step budget.
                    if completed + len(futures) < total_steps:
                        futures.add(executor.submit(step_task, cur_node))
    finally:
        # Release interpreter sessions even if the search loop raised.
        interpreter.cleanup_session(-1)
# Script entry point: start a run only when executed directly, not on import.
if __name__ == "__main__":
    run()