Skip to content

Commit 7d203c6

Browse files
committed
Experimental refactor
1 parent 9b76e78 commit 7d203c6

File tree

2 files changed

+14
-116
lines changed

2 files changed

+14
-116
lines changed

src/taskgraph/config.py

Lines changed: 4 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# License, v. 2.0. If a copy of the MPL was not distributed with this
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

5-
65
import logging
76
import os
87
import sys
@@ -28,25 +27,10 @@
2827

2928
class WorkerAlias(Schema):
3029
"""Worker alias configuration."""
31-
32-
provisioner: Union[str, dict] # Can be keyed-by level
30+
provisioner: optionally_keyed_by("level", str) # type: ignore
3331
implementation: str
3432
os: str
35-
worker_type: Union[str, dict] # Can be keyed-by level, maps from "worker-type"
36-
37-
def __post_init__(self):
38-
"""Validate keyed-by fields."""
39-
# Validate provisioner can be keyed-by level
40-
if isinstance(self.provisioner, dict):
41-
validator = optionally_keyed_by("level", str)
42-
# Just validate - it will raise an error if invalid
43-
validator(self.provisioner)
44-
45-
# Validate worker_type can be keyed-by level
46-
if isinstance(self.worker_type, dict):
47-
validator = optionally_keyed_by("level", str)
48-
# Just validate - it will raise an error if invalid
49-
validator(self.worker_type)
33+
worker_type: optionally_keyed_by("level", str) # type: ignore
5034

5135

5236
class Workers(Schema, rename=None):
@@ -94,49 +78,17 @@ class GraphConfigSchema(Schema):
9478

9579
# Required fields first
9680
trust_domain: str # Maps from "trust-domain"
97-
task_priority: Union[
98-
TaskPriority, dict
99-
] # Maps from "task-priority", can be keyed-by project or level
81+
task_priority: optionally_keyed_by("project", "level", TaskPriority) # type: ignore
10082
workers: Workers
10183
taskgraph: TaskGraphConfig
10284

10385
# Optional fields
10486
docker_image_kind: Optional[str] = None # Maps from "docker-image-kind"
105-
task_deadline_after: Optional[Union[str, dict]] = (
106-
None # Maps from "task-deadline-after", can be keyed-by project
107-
)
87+
task_deadline_after: Optional[optionally_keyed_by("project", str)] = None # type: ignore
10888
task_expires_after: Optional[str] = None # Maps from "task-expires-after"
10989
# Allow extra fields for flexibility
11090
__extras__: Dict[str, Any] = msgspec.field(default_factory=dict)
11191

112-
def __post_init__(self):
113-
"""Validate keyed-by fields."""
114-
# Validate task_priority can be keyed-by project or level
115-
if isinstance(self.task_priority, dict):
116-
# Create a validator that accepts TaskPriority values
117-
def validate_priority(x):
118-
valid_priorities = [
119-
"highest",
120-
"very-high",
121-
"high",
122-
"medium",
123-
"low",
124-
"very-low",
125-
"lowest",
126-
]
127-
if x not in valid_priorities:
128-
raise ValueError(f"Invalid task priority: {x}")
129-
return x
130-
131-
validator = optionally_keyed_by("project", "level", validate_priority)
132-
# Just validate - it will raise an error if invalid
133-
validator(self.task_priority)
134-
135-
# Validate task_deadline_after can be keyed-by project
136-
if self.task_deadline_after and isinstance(self.task_deadline_after, dict):
137-
validator = optionally_keyed_by("project", str)
138-
# Just validate - it will raise an error if invalid
139-
validator(self.task_deadline_after)
14092

14193

14294
# Msgspec schema is now the main schema

src/taskgraph/util/schema.py

Lines changed: 10 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
# License, v. 2.0. If a copy of the MPL was not distributed with this
33
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
44

5-
65
import pprint
7-
from typing import List
6+
from functools import reduce
7+
from typing import Dict, List, Literal, Union
88

99
import msgspec
1010

@@ -37,72 +37,18 @@ def validate_schema(schema, obj, msg_prefix):
3737
raise Exception(f"{msg_prefix}\n{str(exc)}\n{pprint.pformat(obj)}")
3838

3939

40-
def optionally_keyed_by(*arguments):
40+
def UnionTypes(*types):
41+
"""Use `functools.reduce` to simulate `Union[*allowed_types]` on older
42+
Python versions.
4143
"""
42-
Mark a schema value as optionally keyed by any of a number of fields. The
43-
schema is the last argument, and the remaining fields are taken to be the
44-
field names. For example:
44+
return reduce(lambda a, b: Union[a, b], types)
4545

46-
'some-value': optionally_keyed_by(
47-
'test-platform', 'build-platform',
48-
Any('a', 'b', 'c'))
4946

50-
The resulting schema will allow nesting of `by-test-platform` and
51-
`by-build-platform` in either order.
52-
"""
53-
schema = arguments[-1]
47+
def optionally_keyed_by(*arguments):
48+
_type = arguments[-1]
5449
fields = arguments[:-1]
55-
56-
def validator(obj):
57-
if isinstance(obj, dict) and len(obj) == 1:
58-
k, v = list(obj.items())[0]
59-
if k.startswith("by-") and k[len("by-") :] in fields:
60-
res = {}
61-
for kk, vv in v.items():
62-
try:
63-
res[kk] = validator(vv)
64-
except Exception as e:
65-
raise ValueError(f"Error in {k}.{kk}: {str(e)}") from e
66-
return res
67-
elif k.startswith("by-"):
68-
# Unknown by-field
69-
raise ValueError(f"Unknown key {k}")
70-
# Validate against the schema
71-
if isinstance(schema, type) and issubclass(schema, Schema):
72-
return schema.validate(obj)
73-
elif schema is str:
74-
# String validation
75-
if not isinstance(obj, str):
76-
raise TypeError(f"Expected string, got {type(obj).__name__}")
77-
return obj
78-
elif schema is int:
79-
# Int validation
80-
if not isinstance(obj, int):
81-
raise TypeError(f"Expected int, got {type(obj).__name__}")
82-
return obj
83-
elif isinstance(schema, type):
84-
# Type validation for built-in types
85-
if not isinstance(obj, schema):
86-
raise TypeError(f"Expected {schema.__name__}, got {type(obj).__name__}")
87-
return obj
88-
elif callable(schema):
89-
# Other callable validators
90-
try:
91-
return schema(obj)
92-
except:
93-
raise
94-
else:
95-
# Simple type validation
96-
if not isinstance(obj, schema):
97-
raise TypeError(
98-
f"Expected {getattr(schema, '__name__', str(schema))}, got {type(obj).__name__}"
99-
)
100-
return obj
101-
102-
# set to assist autodoc
103-
setattr(validator, "schema", schema)
104-
setattr(validator, "fields", fields)
105-
return validator
50+
bykeys = [Literal[f"by-{field}"] for field in fields]
51+
return Union[_type, Dict[UnionTypes(*bykeys), Dict[str, _type]]]
10652

10753

10854
def resolve_keyed_by(

0 commit comments

Comments
 (0)