|
14 | 14 | import typing |
15 | 15 | from collections.abc import Iterator |
16 | 16 | from contextlib import contextmanager |
| 17 | +from contextlib import nullcontext |
17 | 18 | from functools import partial |
18 | 19 | from subprocess import CompletedProcess |
19 | 20 | from types import FunctionType |
20 | 21 | from types import GenericAlias |
21 | 22 | from typing import Any |
22 | 23 | from typing import cast |
| 24 | +from typing import ContextManager |
23 | 25 | from typing import TYPE_CHECKING |
24 | 26 | from typing import TypedDict |
25 | 27 |
|
@@ -248,38 +250,70 @@ def web(self) -> requests.Session: |
248 | 250 | return requests.Session() |
249 | 251 |
|
250 | 252 |
|
| 253 | +class DefaultVirtualenvConfig: |
| 254 | + """ |
| 255 | + Simple class to hold registered imports. |
| 256 | + """ |
| 257 | + |
| 258 | + _instance: DefaultVirtualenvConfig | None = None |
| 259 | + venv_config: VirtualEnvConfig |
| 260 | + |
| 261 | + def __new__(cls): |
| 262 | + """ |
| 263 | + Method that instantiates a singleton class and returns it. |
| 264 | + """ |
| 265 | + if cls._instance is None: |
| 266 | + instance = super().__new__(cls) |
| 267 | + cls._instance = instance |
| 268 | + return cls._instance |
| 269 | + |
| 270 | + @classmethod |
| 271 | + def set_default_venv_config(cls, venv_config: VirtualEnvConfig) -> None: |
| 272 | + """ |
| 273 | + Register an import. |
| 274 | + """ |
| 275 | + instance = cls._instance |
| 276 | + if instance is None: |
| 277 | + instance = cls() |
| 278 | + if venv_config and "name" not in venv_config: |
| 279 | + venv_config["name"] = "default" |
| 280 | + instance.venv_config = venv_config |
| 281 | + |
| 282 | + |
251 | 283 | class RegisteredImports: |
252 | 284 | """ |
253 | 285 | Simple class to hold registered imports. |
254 | 286 | """ |
255 | 287 |
|
256 | 288 | _instance: RegisteredImports | None = None |
257 | | - _registered_imports: list[str] |
| 289 | + _registered_imports: dict[str, VirtualEnvConfig | None] |
258 | 290 |
|
259 | 291 | def __new__(cls): |
260 | 292 | """ |
261 | 293 | Method that instantiates a singleton class and returns it. |
262 | 294 | """ |
263 | 295 | if cls._instance is None: |
264 | 296 | instance = super().__new__(cls) |
265 | | - instance._registered_imports = [] |
| 297 | + instance._registered_imports = {} |
266 | 298 | cls._instance = instance |
267 | 299 | return cls._instance |
268 | 300 |
|
269 | 301 | @classmethod |
270 | | - def register_import(cls, import_module: str) -> None: |
| 302 | + def register_import( |
| 303 | + cls, import_module: str, venv_config: VirtualEnvConfig | None = None |
| 304 | + ) -> None: |
271 | 305 | """ |
272 | 306 | Register an import. |
273 | 307 | """ |
274 | 308 | instance = cls() |
275 | 309 | if import_module not in instance._registered_imports: |
276 | | - instance._registered_imports.append(import_module) |
| 310 | + instance._registered_imports[import_module] = venv_config |
277 | 311 |
|
278 | 312 | def __iter__(self): |
279 | 313 | """ |
280 | 314 | Return an iterator of all registered imports. |
281 | 315 | """ |
282 | | - return iter(self._registered_imports) |
| 316 | + return iter(self._registered_imports.items()) |
283 | 317 |
|
284 | 318 |
|
285 | 319 | class Parser: |
@@ -370,14 +404,29 @@ def __new__(cls): |
370 | 404 | return cls._instance |
371 | 405 |
|
372 | 406 | def _process_registered_tool_modules(self): |
373 | | - for module_name in RegisteredImports(): |
374 | | - try: |
375 | | - importlib.import_module(module_name) |
376 | | - except ImportError as exc: |
377 | | - if os.environ.get("TOOLS_IGNORE_IMPORT_ERRORS", "0") == "0": |
378 | | - self.context.warn( |
379 | | - f"Could not import the registered tools module {module_name!r}: {exc}" |
380 | | - ) |
| 407 | + default_venv: VirtualEnv | ContextManager[None] |
| 408 | + default_venv_config = DefaultVirtualenvConfig().venv_config |
| 409 | + if default_venv_config: |
| 410 | + default_venv = VirtualEnv(ctx=self.context, **default_venv_config) |
| 411 | + else: |
| 412 | + default_venv = nullcontext() |
| 413 | + with default_venv: |
| 414 | + for module_name, venv_config in RegisteredImports(): |
| 415 | + venv: VirtualEnv | ContextManager[None] |
| 416 | + if venv_config: |
| 417 | + if "name" not in venv_config: |
| 418 | + venv_config["name"] = module_name |
| 419 | + venv = VirtualEnv(ctx=self.context, **venv_config) |
| 420 | + else: |
| 421 | + venv = nullcontext() |
| 422 | + with venv: |
| 423 | + try: |
| 424 | + importlib.import_module(module_name) |
| 425 | + except ImportError as exc: |
| 426 | + if os.environ.get("TOOLS_IGNORE_IMPORT_ERRORS", "0") == "0": |
| 427 | + self.context.warn( |
| 428 | + f"Could not import the registered tools module {module_name!r}: {exc}" |
| 429 | + ) |
381 | 430 |
|
382 | 431 | def parse_args(self): |
383 | 432 | """ |
@@ -471,6 +520,8 @@ def __init__(self, name, help, description=None, parent=None, venv_config=None): |
471 | 520 | GroupReference.add_command(tuple(parent + [name]), self) |
472 | 521 | parent = GroupReference()[tuple(parent)] |
473 | 522 |
|
| 523 | + if venv_config and "name" not in venv_config: |
| 524 | + venv_config["name"] = self.name |
474 | 525 | self.venv_config = venv_config or {} |
475 | 526 | self.parser = parent.subparsers.add_parser( |
476 | 527 | name.replace("_", "-"), |
@@ -634,22 +685,22 @@ def __call__(self, func, options, venv_config: VirtualEnvConfig | None = None): |
634 | 685 | kwargs[name] = getattr(options, name) |
635 | 686 |
|
636 | 687 | bound = signature.bind_partial(*args, **kwargs) |
637 | | - venv = None |
| 688 | + venv: VirtualEnv | ContextManager[None] |
638 | 689 | if venv_config: |
639 | | - venv_name = getattr(options, f"{self.name}_command") |
640 | | - venv = VirtualEnv(name=f"{self.name}.{venv_name}", ctx=self.context, **venv_config) |
| 690 | + if "name" not in venv_config: |
| 691 | + venv_config["name"] = getattr(options, f"{self.name}_command") |
| 692 | + venv = VirtualEnv(ctx=self.context, **venv_config) |
641 | 693 | elif self.venv_config: |
642 | | - venv = VirtualEnv(name=self.name, ctx=self.context, **self.venv_config) |
643 | | - if venv: |
644 | | - with venv: |
645 | | - previous_venv = self.context.venv |
646 | | - try: |
647 | | - self.context.venv = venv |
648 | | - func(self.context, *bound.args, **bound.kwargs) |
649 | | - finally: |
650 | | - self.context.venv = previous_venv |
| 694 | + venv = VirtualEnv(ctx=self.context, **self.venv_config) |
651 | 695 | else: |
652 | | - func(self.context, *bound.args, **bound.kwargs) |
| 696 | + venv = nullcontext() |
| 697 | + with venv: |
| 698 | + previous_venv = self.context.venv |
| 699 | + try: |
| 700 | + self.context.venv = venv |
| 701 | + func(self.context, *bound.args, **bound.kwargs) |
| 702 | + finally: |
| 703 | + self.context.venv = previous_venv |
653 | 704 |
|
654 | 705 |
|
655 | 706 | def command_group( |
|
0 commit comments