Skip to content

Commit 5997131

Browse files
authored
Add verify-entrypoints command to verify entrypoints (#34)
* Add `verify-entrypoints` command to verify entrypoints * Add `--optional-module` option to opt-in modules to safely ignore * Add `dispatch_identifier-{type,callable,backend}` schema kinds * Add a few tests for verifying entrypoints
1 parent 58953b6 commit 5997131

File tree

4 files changed

+372
-4
lines changed

4 files changed

+372
-4
lines changed

src/spatch/__main__.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import argparse
22

3-
from .backend_utils import update_entrypoint
3+
from .backend_utils import update_entrypoint, verify_entrypoint
44

55

66
def main():
@@ -13,14 +13,42 @@ def main():
1313
update_entrypoint_cmd.add_argument(
1414
"paths", type=str, nargs="+", help="paths to the entrypoint toml files to update"
1515
)
16+
update_entrypoint_cmd.add_argument(
17+
"--verify",
18+
action="store_true",
19+
help="verify updated entrypoints",
20+
)
1621

17-
args = parser.parse_args()
22+
verify_entrypoint_cmd = subparsers.add_parser(
23+
"verify-entrypoints", help="verify the entrypoint toml file"
24+
)
25+
verify_entrypoint_cmd.add_argument(
26+
"paths", type=str, nargs="+", help="paths to the entrypoint toml files to verify"
27+
)
28+
verify_entrypoint_cmd.add_argument(
29+
"--optional-module",
30+
action="append",
31+
type=str,
32+
help="add a top-level module that may be ignored during verification "
33+
"(useful when identifiers refer to optional packages)",
34+
)
1835

36+
args = parser.parse_args()
37+
verify = False
38+
optional_modules = None
1939
if args.subcommand == "update-entrypoints":
2040
for path in args.paths:
2141
update_entrypoint(path)
42+
verify = args.verify
43+
elif args.subcommand == "verify-entrypoints":
44+
verify = True
45+
if args.optional_module:
46+
optional_modules = set(args.optional_module)
2247
else:
2348
raise RuntimeError("unreachable: subcommand not known.")
49+
if verify:
50+
for path in args.paths:
51+
verify_entrypoint(path, optional_modules=optional_modules)
2452

2553

2654
if __name__ == "__main__":

src/spatch/_spatch_example/README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ The "library" contains only:
1515

1616
We then have two backends with their corresponding definitions in `backend.py`.
1717
The entry-points are `entry_point.toml` and `entry_point2.toml`. When code changes,
18-
these can be updated via `python -m spin update-entrypoints *.toml`
19-
(the necessary info is in the file itself).
18+
these can be updated via `python -m spatch update-entrypoints *.toml`
19+
(the necessary info is in the file itself) and verified with
20+
`python -m spatch verify-entrypoints *.toml`.
2021

2122
For users we have the following basic capabilities. Starting with normal
2223
type dispatching.

src/spatch/backend_utils.py

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pathlib
2+
import warnings
23
from collections.abc import Callable
34
from dataclasses import dataclass
45

@@ -226,3 +227,221 @@ def update_entrypoint(filepath: str):
226227

227228
with pathlib.Path(filepath).open(mode="w") as f:
228229
tomlkit.dump(data, f)
230+
231+
232+
def verify_entrypoint(filepath: str, optional_modules: set | None = None):
233+
try:
234+
import tomllib
235+
except ImportError:
236+
import tomli as tomllib # for Python 3.10 support
237+
238+
with pathlib.Path(filepath).open("rb") as f:
239+
data = tomllib.load(f)
240+
241+
_verify_entrypoint_dict(data, optional_modules)
242+
243+
244+
def _verify_entrypoint_dict(data: dict, optional_modules: set | None = None):
245+
from importlib import import_module
246+
247+
schema = {
248+
"name": "python_identifier",
249+
"primary_types": ["dispatch_identifier-type"],
250+
"secondary_types": ["dispatch_identifier-type"],
251+
"requires_opt_in": bool,
252+
"higher_priority_than?": ["python_identifier"],
253+
"lower_priority_than?": ["python_identifier"],
254+
"functions": {
255+
"auto-generation?": {
256+
"backend": "dispatch_identifier-backend",
257+
"modules?": "modules",
258+
},
259+
"defaults?": {
260+
"function?": "dispatch_identifier-callable",
261+
"should_run?": "dispatch_identifier-callable",
262+
"additional_docs?": str,
263+
"uses_context?": bool,
264+
},
265+
},
266+
}
267+
function_schema = {
268+
"function": "dispatch_identifier-callable",
269+
"should_run?": "dispatch_identifier-callable",
270+
"additional_docs?": str,
271+
"uses_context?": bool,
272+
}
273+
274+
def to_path_key(path):
275+
# We indicate list elements with [i], which isn't proper toml
276+
path_key = ".".join(f'"{key}"' if "." in key or ":" in key else key for key in path)
277+
return path_key.replace(".[", "[")
278+
279+
def handle_bool(path_key, val):
280+
if not isinstance(val, bool):
281+
raise TypeError(f"{path_key} = {val} value is not a bool; got type {type(val)}")
282+
283+
def handle_str(path_key, val):
284+
if not isinstance(val, str):
285+
raise TypeError(f"{path_key} = {val} value is not a str; got type {type(val)}")
286+
287+
def handle_python_identifier(path_key, val):
288+
handle_str(path_key, val)
289+
if not val.isidentifier():
290+
raise ValueError(f"{path_key} = {val} value is not a valid Python identifier")
291+
292+
def handle_dispatch_identifier(path_key, val, path):
293+
handle_str(path_key, val)
294+
try:
295+
from_identifier(val)
296+
except ModuleNotFoundError as exc:
297+
module_name = val.split(":", 1)[0].split(".", 1)[0]
298+
if optional_modules and module_name in optional_modules:
299+
warnings.warn(
300+
f"{path_key} = {val} identifier not found: {exc.args[0]}",
301+
UserWarning,
302+
len(path) + 5,
303+
)
304+
return False
305+
raise ValueError(f"{path_key} = {val} identifier not found") from exc
306+
except AttributeError as exc:
307+
raise ValueError(f"{path_key} = {val} identifier not found") from exc
308+
return True
309+
310+
def handle_dispatch_identifier_type(path_key, val):
311+
reified_val = from_identifier(val)
312+
if not isinstance(reified_val, type):
313+
raise TypeError(f"{path_key} = {val} value must be a type (such as a class)")
314+
315+
def handle_dispatch_identifier_callable(path_key, val):
316+
reified_val = from_identifier(val)
317+
if not callable(reified_val):
318+
raise TypeError(f"{path_key} = {val} value must be callable")
319+
320+
def handle_dispatch_identifier_backend(path_key, val, backend_name):
321+
reified_val = from_identifier(val)
322+
if not isinstance(reified_val, BackendImplementation):
323+
# Is this too strict?
324+
raise TypeError(f"{path_key} = {val} value must be a BackendImplementation object")
325+
326+
if reified_val.name != backend_name:
327+
raise ValueError(
328+
f"{path_key} = {val} backend name does not match the name "
329+
f"in the toml file: {reified_val.name!r} != {backend_name!r}"
330+
)
331+
332+
def handle_modules(path_key, val, path):
333+
if isinstance(val, str):
334+
val = [val]
335+
elif not isinstance(val, list):
336+
raise TypeError(f"{path_key} = {val} value is not a str or list; got type {type(val)}")
337+
for i, module_name in enumerate(val):
338+
inner_path_key = f"{path_key}[{i}]"
339+
handle_str(inner_path_key, module_name)
340+
try:
341+
import_module(module_name)
342+
except ModuleNotFoundError as exc:
343+
mod_name = module_name.split(".", 1)[0]
344+
if optional_modules and mod_name in optional_modules:
345+
warnings.warn(
346+
f"{inner_path_key} = {module_name} module not found",
347+
UserWarning,
348+
len(path) + 5,
349+
)
350+
else:
351+
raise ValueError(f"{inner_path_key} = {module_name} module not found") from exc
352+
353+
def check_schema(schema, data, backend_name, path=()):
354+
# Show possible misspellings with a warning
355+
schema_keys = {key.removesuffix("?") for key in schema}
356+
if extra_keys := data.keys() - schema_keys:
357+
path_key = to_path_key(path)
358+
extra_keys = ", ".join(sorted(extra_keys))
359+
warnings.warn(
360+
f'"{path_key}" section has extra keys: {extra_keys}',
361+
UserWarning,
362+
len(path) + 4,
363+
)
364+
365+
for schema_key, schema_val in schema.items():
366+
key = schema_key.removesuffix("?")
367+
path_key = to_path_key((*path, key))
368+
if len(key) != len(schema_key): # optional key
369+
if key not in data:
370+
continue
371+
elif key not in data:
372+
raise KeyError(f"Missing required key: {path_key}")
373+
374+
val = data[key]
375+
if schema_val is bool:
376+
handle_bool(path_key, val)
377+
elif schema_val is str:
378+
handle_str(path_key, val)
379+
elif isinstance(schema_val, dict):
380+
if not isinstance(val, dict):
381+
raise TypeError(f"{path_key} value is not a dict; got type {type(val)}")
382+
check_schema(schema_val, val, backend_name, (*path, key))
383+
elif isinstance(schema_val, list):
384+
if not isinstance(val, list):
385+
raise TypeError(f"{path_key} value is not a list; got type {type(val)}")
386+
val_as_dict = {f"[{i}]": x for i, x in enumerate(val)}
387+
schema_as_dict = dict.fromkeys(val_as_dict, schema_val[0])
388+
check_schema(schema_as_dict, val_as_dict, backend_name, (*path, key))
389+
elif schema_val == "python_identifier":
390+
handle_python_identifier(path_key, val)
391+
elif schema_val.startswith("dispatch_identifier"):
392+
if not handle_dispatch_identifier(path_key, val, path):
393+
continue
394+
if schema_val.endswith("-type"):
395+
handle_dispatch_identifier_type(path_key, val)
396+
elif schema_val.endswith("-callable"):
397+
handle_dispatch_identifier_callable(path_key, val)
398+
elif schema_val.endswith("-backend"):
399+
handle_dispatch_identifier_backend(path_key, val, backend_name)
400+
elif "-" in schema_val:
401+
raise RuntimeError(f"unreachable: unknown schema: {schema_val}")
402+
elif schema_val == "modules":
403+
handle_modules(path_key, val, path)
404+
else:
405+
raise RuntimeError(f"unreachable: unknown schema: {schema_val}")
406+
407+
if not isinstance(data, dict):
408+
raise TypeError(f"toml data must be a dict; got type {type(data)}")
409+
if not isinstance(data.get("functions"), dict):
410+
raise TypeError(f"functions value is not a dict; got type {type(data.get('functions'))}")
411+
412+
backend_name = data.get("name", "<unknown>")
413+
414+
# Check everything except schema for dispatched functions
415+
fixed_functions_schema = {key.removesuffix("?") for key in schema.get("functions", {})}
416+
fixed_data = dict(data)
417+
fixed_data["functions"] = {
418+
key: val for key, val in data["functions"].items() if key in fixed_functions_schema
419+
}
420+
check_schema(schema, fixed_data, backend_name)
421+
422+
# Now check the dispatched functions
423+
dynamic_functions_data = dict(data)
424+
dynamic_functions_data = {
425+
"functions": {
426+
key: val for key, val in data["functions"].items() if key not in fixed_functions_schema
427+
}
428+
}
429+
dynamic_functions_schema = {
430+
"functions": dict.fromkeys(
431+
dynamic_functions_data["functions"],
432+
function_schema,
433+
),
434+
}
435+
check_schema(dynamic_functions_schema, dynamic_functions_data, backend_name)
436+
437+
# And a small hack to check the *keys* of dispatched functions
438+
function_keys_data = {
439+
"functions": {key: key for key in dynamic_functions_data["functions"]},
440+
}
441+
function_keys_schema = {
442+
"functions": dict.fromkeys(
443+
dynamic_functions_data["functions"],
444+
"dispatch_identifier-callable",
445+
),
446+
}
447+
check_schema(function_keys_schema, function_keys_data, backend_name)

0 commit comments

Comments
 (0)