|
1 | 1 | import logging |
2 | | -import time |
3 | 2 | from collections.abc import MutableMapping |
4 | 3 | from concurrent.futures import wait |
5 | 4 | from pathlib import Path |
|
14 | 13 | cli, |
15 | 14 | click_parameter_decorators_from_typed_dict, |
16 | 15 | ) |
| 16 | +from ..models import TaskConfig |
17 | 17 |
|
18 | 18 | log = logging.getLogger(__name__) |
19 | 19 |
|
@@ -95,27 +95,136 @@ def format_bool_option(opt_name: str, value: Any, skip: bool = False, raw_key: s |
95 | 95 | return args_arr |
96 | 96 |
|
97 | 97 |
|
98 | | -@cli.command() |
99 | | -@click_parameter_decorators_from_typed_dict(BatchCliTypedDict) |
100 | | -def BatchCli(): |
101 | | - ctx = click.get_current_context() |
102 | | - batch_config = ctx.default_map |
| 98 | +def build_task_from_config(cmd_name: str, config_dict: dict[str, Any]) -> TaskConfig | None: |
103 | 99 |
|
104 | | - runner = CliRunner() |
| 100 | + collected_tasks = [] |
| 101 | + original_run = None |
105 | 102 |
|
106 | | - args_arr = build_sub_cmd_args(batch_config) |
| 103 | + try: |
| 104 | + from ..interface import benchmark_runner |
| 105 | + |
| 106 | + original_run = benchmark_runner.run |
| 107 | + |
| 108 | + def collect_task_wrapper(tasks: list[TaskConfig], task_label: str | None = None): # noqa: ARG001 |
| 109 | + collected_tasks.extend(tasks) |
| 110 | + return True |
| 111 | + |
| 112 | + benchmark_runner.run = collect_task_wrapper # intercept run() so the sub-command only records its tasks |
| 113 | + |
| 114 | + # translate the config dict into CLI arguments for the sub-command |
| 115 | + args = [cmd_name] |
| 116 | + bool_options = { |
| 117 | + "drop_old": True, |
| 118 | + "load": True, |
| 119 | + "search_serial": True, |
| 120 | + "search_concurrent": True, |
| 121 | + "dry_run": False, |
| 122 | + "custom_dataset_use_shuffled": True, |
| 123 | + "custom_dataset_with_gt": True, |
| 124 | + } |
| 125 | + |
| 126 | + def format_option(key: str, value: Any): |
| 127 | + opt_name = key.replace("_", "-") |
| 128 | + |
| 129 | + if key in bool_options: |
| 130 | + return format_bool_option(opt_name, value, skip=False) |
| 131 | + |
| 132 | + if key.startswith("skip_"): |
| 133 | + raw_key = key[5:] |
| 134 | + raw_opt = raw_key.replace("_", "-") |
| 135 | + return format_bool_option(raw_opt, value, skip=True, raw_key=raw_key) |
| 136 | + |
| 137 | + return [f"--{opt_name}", str(value)] |
| 138 | + |
| 139 | + def format_bool_option(opt_name: str, value: Any, skip: bool = False, raw_key: str | None = None): |
| 140 | + if isinstance(value, bool): |
| 141 | + if skip: |
| 142 | + if bool_options.get(raw_key, False): |
| 143 | + return [f"--skip-{opt_name}"] if value else [f"--{opt_name}"] |
| 144 | + return [f"--{opt_name}", str(value)] |
| 145 | + if value: |
| 146 | + return [f"--{opt_name}"] |
| 147 | + if bool_options.get(opt_name.replace("-", "_"), False): |
| 148 | + return [f"--skip-{opt_name}"] |
| 149 | + return [] |
| 150 | + return [f"--{opt_name}", str(value)] |
| 151 | + |
| 152 | + for k, v in config_dict.items(): |
| 153 | + args.extend(format_option(k, v)) |
| 154 | + |
| 155 | + # call CLI command (this will trigger collect_task_wrapper) |
| 156 | + runner = CliRunner() |
| 157 | + result = runner.invoke(cli, args, catch_exceptions=False) |
107 | 158 |
|
108 | | - for args in args_arr: |
109 | | - log.info(f"got batch config: {' '.join(args)}") |
| 159 | + if result.exception: |
| 160 | + log.error(f"Failed to build task for {cmd_name}: {result.exception}") |
| 161 | + return None |
110 | 162 |
|
111 | | - for args in args_arr: |
112 | | - result = runner.invoke(cli, args) |
113 | | - time.sleep(5) |
| 163 | + if collected_tasks: |
| 164 | + return collected_tasks[0] |
| 165 | + return None # noqa: TRY300 |
114 | 166 |
|
115 | | - from ..interface import global_result_future |
| 167 | + except Exception: |
| 168 | + log.exception("Error building task from config") |
| 169 | + return None |
| 170 | + finally: |
| 171 | + if original_run is not None: |
| 172 | + from ..interface import benchmark_runner |
116 | 173 |
|
117 | | - if global_result_future: |
118 | | - wait([global_result_future]) |
| 174 | + benchmark_runner.run = original_run |
119 | 175 |
|
120 | | - if result.exception: |
121 | | - log.exception(f"failed to run sub command: {args[0]}", exc_info=result.exception) |
| 176 | + |
| 177 | +@cli.command() |
| 178 | +@click_parameter_decorators_from_typed_dict(BatchCliTypedDict) |
| 179 | +def BatchCli(): |
| 180 | + ctx = click.get_current_context() |
| 181 | + batch_config = ctx.default_map |
| 182 | + |
| 183 | + from ..interface import benchmark_runner |
| 184 | + |
| 185 | + # collect all tasks |
| 186 | + all_tasks: list[TaskConfig] = [] |
| 187 | + task_labels: set[str] = set() |
| 188 | + |
| 189 | + for cmd_name, cmd_config_list in batch_config.items(): |
| 190 | + for config_dict in cmd_config_list: |
| 191 | + log.info(f"Building task for {cmd_name} (task_label: {config_dict.get('task_label', 'N/A')})") |
| 192 | + |
| 193 | + # collect task_label from config |
| 194 | + if "task_label" in config_dict: |
| 195 | + task_labels.add(config_dict["task_label"]) |
| 196 | + |
| 197 | + # replay the sub-command through the CLI to build a TaskConfig |
| 198 | + task = build_task_from_config(cmd_name, config_dict) |
| 199 | + if task: |
| 200 | + all_tasks.append(task) |
| 201 | + log.info(f"Successfully built task: {task.db.value} - {task.case_config.case_id.name}") |
| 202 | + else: |
| 203 | + log.warning(f"Failed to build task for {cmd_name}") |
| 204 | + |
| 205 | + if not all_tasks: |
| 206 | + log.error("No tasks were built from the batch config file") |
| 207 | + return |
| 208 | + |
| 209 | + if len(task_labels) == 1: |
| 210 | + task_label = task_labels.pop() |
| 211 | + log.info(f"Using shared task_label from config: {task_label}") |
| 212 | + elif len(task_labels) > 1: |
| 213 | + task_label = next(iter(task_labels)) |
| 214 | + log.warning(f"Multiple task_labels found in config; using {task_label!r} for all tasks") |
| 215 | + else: |
| 216 | + from datetime import datetime |
| 217 | + |
| 218 | + task_label = f"batch_{datetime.now().strftime('%Y%m%d_%H%M%S')}" |
| 219 | + log.info(f"No task_label found in config, using generated one: {task_label}") |
| 220 | + |
| 221 | + log.info(f"Running {len(all_tasks)} tasks with shared task_label: {task_label}") |
| 222 | + |
| 223 | + benchmark_runner.run(all_tasks, task_label) |
| 224 | + |
| 225 | + # import after run() so the future assigned by benchmark_runner is visible |
| 226 | + from ..interface import global_result_future |
| 227 | + |
| 228 | + if global_result_future: |
| 229 | + log.info("Waiting for all tasks to complete...") |
| 230 | + wait([global_result_future]) |
| 231 | + log.info("All tasks completed") |
| 232 | + else: |
| 233 | + log.warning("No global_result_future found; tasks may be running in the background") |
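For reference, `BatchCli` treats `ctx.default_map` as a mapping of sub-command name to a list of option dicts. A minimal sketch of that shape, with hypothetical sub-command and option names (only the `{command: [config, ...]}` nesting, `task_label`, and the bool/`skip_*` keys are implied by the code above):

```python
# Hypothetical batch config as BatchCli receives it via ctx.default_map.
# "SomeVectorDb" is a made-up sub-command name; task_label, drop_old and
# skip_drop_old follow the handling shown in the diff.
batch_config = {
    "SomeVectorDb": [
        {
            "task_label": "nightly",  # shared label collected across all entries
            "drop_old": True,         # bool option -> expands to --drop-old
        },
        {
            "task_label": "nightly",
            "skip_drop_old": True,    # skip_* key -> expands to --skip-drop-old
        },
    ],
}
```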
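The key-to-flag expansion itself can be checked in isolation. Below is a self-contained sketch mirroring the nested helpers above (the `k` key in the demo is hypothetical, not a confirmed CLI option):

```python
from typing import Any

# Mirrors the bool_options table in build_task_from_config; values are the
# options' defaults, which decide whether False needs an explicit skip flag.
BOOL_OPTIONS = {
    "drop_old": True,
    "load": True,
    "search_serial": True,
    "search_concurrent": True,
    "dry_run": False,
    "custom_dataset_use_shuffled": True,
    "custom_dataset_with_gt": True,
}

def format_bool_option(opt_name: str, value: Any, skip: bool = False,
                       raw_key: str | None = None) -> list[str]:
    if isinstance(value, bool):
        if skip:
            if BOOL_OPTIONS.get(raw_key, False):
                return [f"--skip-{opt_name}"] if value else [f"--{opt_name}"]
            return [f"--{opt_name}", str(value)]
        if value:
            return [f"--{opt_name}"]
        if BOOL_OPTIONS.get(opt_name.replace("-", "_"), False):
            return [f"--skip-{opt_name}"]
        return []  # default is already False, so no flag is needed
    return [f"--{opt_name}", str(value)]

def format_option(key: str, value: Any) -> list[str]:
    opt_name = key.replace("_", "-")
    if key in BOOL_OPTIONS:
        return format_bool_option(opt_name, value)
    if key.startswith("skip_"):
        raw_key = key[5:]
        return format_bool_option(raw_key.replace("_", "-"), value, skip=True, raw_key=raw_key)
    return [f"--{opt_name}", str(value)]

print(format_option("drop_old", True))   # ['--drop-old']
print(format_option("drop_old", False))  # ['--skip-drop-old'] (default is True)
print(format_option("dry_run", False))   # [] (already the default)
print(format_option("skip_load", True))  # ['--skip-load']
print(format_option("k", 100))           # ['--k', '100'] (hypothetical key)
```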