|
18 | 18 |
|
19 | 19 | from collections import namedtuple
|
20 | 20 | import argparse
|
| 21 | +import importlib |
| 22 | +import inspect |
21 | 23 | import logging
|
22 | 24 | import typing
|
23 | 25 |
|
@@ -174,6 +176,30 @@ def dict(self) -> dict[str, Command]:
|
174 | 176 | """Return a dict mapping command names to Command object."""
|
175 | 177 | return {c.name: c for c in self._commands}
|
176 | 178 |
|
| 179 | + def include( |
| 180 | + self, modname: str, *, package: str = "", check: bool = True |
| 181 | + ) -> None: |
| 182 | + """Import a python module to add commands to this command builder. |
| 183 | + If check is true and no new commands are added by the import, raise an |
| 184 | + error. |
| 185 | + """ |
| 186 | + if modname.startswith(".") and not package: |
| 187 | + package = "sambacc.commands" |
| 188 | + mod = importlib.import_module(modname, package=package) |
| 189 | + if not check: |
| 190 | + return |
| 191 | + loaded_fns = {c.cmd_func for c in self._commands} |
| 192 | + mod_fns = {fn for _, fn in inspect.getmembers(mod, inspect.isfunction)} |
| 193 | + if not mod_fns.intersection(loaded_fns): |
| 194 | + raise Fail(f"import from {modname} did not add any new commands") |
| 195 | + |
| 196 | + def include_multiple( |
| 197 | + self, modnames: typing.Iterable[str], *, package: str = "" |
| 198 | + ) -> None: |
| 199 | + """Run the include function on multiple module names.""" |
| 200 | + for modname in modnames: |
| 201 | + self.include(modname, package=package) |
| 202 | + |
177 | 203 |
|
178 | 204 | class Context(typing.Protocol):
|
179 | 205 | """Protocol type for CLI Context.
|
|
0 commit comments