Commit 1d23c98

feat: added support for optional keying
1 parent: 7fddbd9

File tree: 7 files changed (+163, -181 lines)

src/taskgraph/config.py

Lines changed: 48 additions & 5 deletions
@@ -13,7 +13,7 @@
 import msgspec
 
 from .util.python_path import find_object
-from .util.schema import Schema, validate_schema
+from .util.schema import Schema, optionally_keyed_by, validate_schema
 from .util.vcs import get_repository
 from .util.yaml import load_yaml
 
@@ -29,10 +29,24 @@
 class WorkerAlias(Schema):
     """Worker alias configuration."""
 
-    provisioner: Union[str, dict]
+    provisioner: Union[str, dict]  # Can be keyed-by level
     implementation: str
     os: str
-    worker_type: Union[str, dict]  # Can be keyed-by, maps from "worker-type"
+    worker_type: Union[str, dict]  # Can be keyed-by level, maps from "worker-type"
+
+    def __post_init__(self):
+        """Validate keyed-by fields."""
+        # Validate provisioner can be keyed-by level
+        if isinstance(self.provisioner, dict):
+            validator = optionally_keyed_by("level", str)
+            # Just validate - it will raise an error if invalid
+            validator(self.provisioner)
+
+        # Validate worker_type can be keyed-by level
+        if isinstance(self.worker_type, dict):
+            validator = optionally_keyed_by("level", str)
+            # Just validate - it will raise an error if invalid
+            validator(self.worker_type)
 
 
 class Workers(Schema, rename=None):
@@ -82,19 +96,48 @@ class GraphConfigSchema(Schema):
     trust_domain: str  # Maps from "trust-domain"
     task_priority: Union[
         TaskPriority, dict
-    ]  # Maps from "task-priority", can be keyed-by
+    ]  # Maps from "task-priority", can be keyed-by project or level
     workers: Workers
     taskgraph: TaskGraphConfig
 
     # Optional fields
     docker_image_kind: Optional[str] = None  # Maps from "docker-image-kind"
     task_deadline_after: Optional[Union[str, dict]] = (
-        None  # Maps from "task-deadline-after", can be keyed-by
+        None  # Maps from "task-deadline-after", can be keyed-by project
    )
     task_expires_after: Optional[str] = None  # Maps from "task-expires-after"
     # Allow extra fields for flexibility
     __extras__: Dict[str, Any] = msgspec.field(default_factory=dict)
 
+    def __post_init__(self):
+        """Validate keyed-by fields."""
+        # Validate task_priority can be keyed-by project or level
+        if isinstance(self.task_priority, dict):
+            # Create a validator that accepts TaskPriority values
+            def validate_priority(x):
+                valid_priorities = [
+                    "highest",
+                    "very-high",
+                    "high",
+                    "medium",
+                    "low",
+                    "very-low",
+                    "lowest",
+                ]
+                if x not in valid_priorities:
+                    raise ValueError(f"Invalid task priority: {x}")
+                return x
+
+            validator = optionally_keyed_by("project", "level", validate_priority)
+            # Just validate - it will raise an error if invalid
+            validator(self.task_priority)
+
+        # Validate task_deadline_after can be keyed-by project
+        if self.task_deadline_after and isinstance(self.task_deadline_after, dict):
+            validator = optionally_keyed_by("project", str)
+            # Just validate - it will raise an error if invalid
+            validator(self.task_deadline_after)
+
 
 # Msgspec schema is now the main schema
 graph_config_schema = GraphConfigSchema
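
Note: the keyed-by values these new __post_init__ validators accept are dicts rooted at a "by-<field>" key. A minimal sketch of a worker alias in that shape (provisioner and worker names here are invented for illustration; resolving the keyed-by dict down to a concrete string still happens later, outside this schema):

    # Hypothetical worker alias entry: "worker-type" is keyed by level, while
    # "provisioner" stays a plain string. Both forms pass WorkerAlias validation.
    worker_alias = {
        "provisioner": "example-provisioner",
        "implementation": "docker-worker",
        "os": "linux",
        "worker-type": {
            "by-level": {
                "3": "example-trusted-worker",
                "default": "example-untrusted-worker",
            },
        },
    }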

src/taskgraph/parameters.py

Lines changed: 35 additions & 88 deletions
@@ -175,8 +175,8 @@ def extend_parameters_schema(schema, defaults_fn=None):
     # Store the extension schema for use during validation
     _schema_extensions.append(schema)
 
-    # Schema extension is no longer supported with msgspec.Struct inheritance
-    # Extensions are tracked in _schema_extensions list instead
+    # With msgspec, schema extensions are tracked in the _schema_extensions list
+    # for validation purposes rather than being merged into a single schema
 
     if defaults_fn:
         defaults_functions.append(defaults_fn)
@@ -240,92 +240,39 @@ def _fill_defaults(repo_root=None, **kwargs):
         return kwargs
 
     def check(self):
-        # For msgspec schemas, we need to validate differently
-        if isinstance(base_schema, type) and issubclass(base_schema, msgspec.Struct):
-            try:
-                # Convert underscore keys to kebab-case for msgspec validation
-                params = self.copy()
-                # BaseSchema uses kebab-case (rename="kebab"), so we need to convert keys
-                kebab_params = {}
-                for k, v in params.items():
-                    # Convert underscore to kebab-case
-                    kebab_key = k.replace("_", "-")
-                    kebab_params[kebab_key] = v
-
-                # Handle extensions if present
-                global _schema_extensions
-                for ext_schema in _schema_extensions:
-                    if isinstance(ext_schema, dict):
-                        # Simple dict validation - just check if required keys exist
-                        for key in ext_schema:
-                            # Just skip validation of extensions for now
-                            pass
-
-                if self.strict:
-                    # Strict validation with msgspec
-                    # First check for extra fields
-                    schema_fields = {
-                        f.encode_name for f in msgspec.structs.fields(base_schema)
-                    }
-
-                    # Add extension fields if present
-                    for ext_schema in _schema_extensions:
-                        if isinstance(ext_schema, dict):
-                            for key in ext_schema.keys():
-                                # Extract field name
-                                if hasattr(key, "key"):
-                                    field_name = key.key.replace("_", "-")
-                                else:
-                                    field_name = str(key).replace("_", "-")
-                                schema_fields.add(field_name)
-
-                    extra_fields = set(kebab_params.keys()) - schema_fields
-                    if extra_fields:
-                        raise ParameterMismatch(
-                            f"Invalid parameters: Extra fields not allowed: {extra_fields}"
-                        )
-                    # Now validate the base schema fields
-                    base_fields = {
-                        f.encode_name for f in msgspec.structs.fields(base_schema)
-                    }
-                    base_params = {
-                        k: v for k, v in kebab_params.items() if k in base_fields
-                    }
-                    msgspec.convert(base_params, base_schema)
-                else:
-                    # Non-strict: validate only the fields that exist in the schema
-                    # Filter to only schema fields
-                    schema_fields = {
-                        f.encode_name for f in msgspec.structs.fields(base_schema)
-                    }
-                    filtered_params = {
-                        k: v for k, v in kebab_params.items() if k in schema_fields
-                    }
-                    msgspec.convert(filtered_params, base_schema)
-            except (msgspec.ValidationError, msgspec.DecodeError) as e:
-                raise ParameterMismatch(f"Invalid parameters: {e}")
-        else:
-            # For non-msgspec schemas, validate using the Schema class
-            from taskgraph.util.schema import validate_schema  # noqa: PLC0415
-
-            try:
-                if self.strict:
-                    validate_schema(base_schema, self.copy(), "Invalid parameters:")
-                else:
-                    # In non-strict mode, allow extra fields
-                    if hasattr(base_schema, "allow_extra"):
-                        original_allow_extra = base_schema.allow_extra
-                        base_schema.allow_extra = True
-                        try:
-                            validate_schema(
-                                base_schema, self.copy(), "Invalid parameters:"
-                            )
-                        finally:
-                            base_schema.allow_extra = original_allow_extra
-                    else:
-                        validate_schema(base_schema, self.copy(), "Invalid parameters:")
-            except Exception as e:
-                raise ParameterMismatch(str(e))
+        # Validate parameters using msgspec schema
+        try:
+            # Convert underscore keys to kebab-case since BaseSchema uses rename="kebab"
+            kebab_params = {k.replace("_", "-"): v for k, v in self.items()}
+
+            if self.strict:
+                # Strict mode: validate against schema and check for extra fields
+                # Get all valid field names from the base schema
+                schema_fields = {
+                    f.encode_name for f in msgspec.structs.fields(base_schema)
+                }
+
+                # Check for extra fields
+                extra_fields = set(kebab_params.keys()) - schema_fields
+                if extra_fields:
+                    raise ParameterMismatch(
+                        f"Invalid parameters: Extra fields not allowed: {extra_fields}"
+                    )
+
+                # Validate all parameters against the schema
+                msgspec.convert(kebab_params, base_schema)
+            else:
+                # Non-strict mode: only validate fields that exist in the schema
+                # Filter to only include fields defined in the schema
+                schema_fields = {
+                    f.encode_name for f in msgspec.structs.fields(base_schema)
+                }
+                filtered_params = {
+                    k: v for k, v in kebab_params.items() if k in schema_fields
+                }
+                msgspec.convert(filtered_params, base_schema)
+        except (msgspec.ValidationError, msgspec.DecodeError) as e:
+            raise ParameterMismatch(f"Invalid parameters: {e}")
 
     def __getitem__(self, k):
         try:
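
The rewritten check() leans on three msgspec pieces: rename="kebab" on the base schema, msgspec.structs.fields() for the encoded field names, and msgspec.convert() for validation. A self-contained sketch of the same pattern on a toy struct (ToySchema and its fields are invented for illustration, not the real parameters schema):

    import msgspec

    class ToySchema(msgspec.Struct, rename="kebab"):
        # Stand-in for the real parameters schema; fields are illustrative only.
        head_ref: str
        tasks_for: str
        level: str = "1"

    params = {"head_ref": "abc123", "tasks_for": "github-push", "extra_key": 1}

    # Parameters are stored with underscores; the schema encodes kebab-case.
    kebab_params = {k.replace("_", "-"): v for k, v in params.items()}
    schema_fields = {f.encode_name for f in msgspec.structs.fields(ToySchema)}

    # Strict mode: any key outside the schema is an error.
    extra_fields = set(kebab_params) - schema_fields  # {"extra-key"}

    # Non-strict mode: drop unknown keys, then validate what remains.
    filtered = {k: v for k, v in kebab_params.items() if k in schema_fields}
    msgspec.convert(filtered, ToySchema)  # raises msgspec.ValidationError if invalid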

src/taskgraph/transforms/docker_image.py

Lines changed: 10 additions & 17 deletions
@@ -30,34 +30,27 @@
 
 #: Schema for docker_image transforms
 class DockerImageSchema(Schema):
-    """
-    Schema for docker_image transforms.
-
-    Attributes:
-        name: Name of the docker image.
-        parent: Name of the parent docker image.
-        symbol: Treeherder symbol.
-        task_from: Relative path (from config.path) to the file the docker image was defined in.
-        args: Arguments to use for the Dockerfile.
-        definition: Name of the docker image definition under taskcluster/docker, when
-            different from the docker image name.
-        packages: List of package tasks this docker image depends on.
-        index: Information for indexing this build so its artifacts can be discovered.
-        cache: Whether this image should be cached based on inputs.
-    """
-
     # Required field first
+    # Name of the docker image.
     name: str
 
     # Optional fields
+    # Name of the parent docker image.
     parent: Optional[str] = None
+    # Treeherder symbol.
     symbol: Optional[str] = None
+    # Relative path (from config.path) to the file the docker image was defined in.
     task_from: Optional[str] = None
+    # Arguments to use for the Dockerfile.
     args: Optional[Dict[str, str]] = None
+    # Name of the docker image definition under taskcluster/docker, when
+    # different from the docker image name.
     definition: Optional[str] = None
+    # List of package tasks this docker image depends on.
     packages: Optional[List[str]] = None
-    # For now, use Any for index since task_description_schema is not converted yet
+    # Information for indexing this build so its artifacts can be discovered.
     index: Optional[Any] = None
+    # Whether this image should be cached based on inputs.
     cache: Optional[bool] = None
 
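
The field comments above carry the same information the removed docstring did. A hypothetical entry that would satisfy DockerImageSchema (image and package names invented for illustration; only name is required):

    image = DockerImageSchema(
        name="linux",
        parent="base",
        definition="linux",  # taskcluster/docker/linux
        args={"PYTHON_VERSION": "3.11"},
        packages=["python-dependencies"],
        cache=True,
    )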

src/taskgraph/transforms/fetch.py

Lines changed: 15 additions & 18 deletions
@@ -33,32 +33,29 @@ class FetchConfig(Schema, rename=None, omit_defaults=False):
 
 
 class FetchSchema(Schema):
-    """
-    Schema for fetch transforms.
-
-    Attributes:
-        name: Name of the task.
-        task_from: Relative path (from config.path) to the file the task was defined in.
-        description: Description of the task.
-        expires_after: When the task expires.
-        docker_image: Docker image configuration.
-        fetch_alias: An alias that can be used instead of the real fetch task name in
-            fetch stanzas for tasks.
-        artifact_prefix: The prefix of the taskcluster artifact being uploaded.
-            Defaults to `public/`; if it starts with something other than
-            `public/` the artifact will require scopes to access.
-        attributes: Task attributes.
-        fetch: Fetch configuration with type and additional fields.
-    """
-
+    # Required fields
+    # Name of the task.
     name: str
+    # Description of the task.
     description: str
+    # Fetch configuration with type and additional fields.
     fetch: Dict[str, Any]  # Must have 'type' key, other keys depend on type
+
+    # Optional fields
+    # Relative path (from config.path) to the file the task was defined in.
     task_from: Optional[str] = None
+    # When the task expires.
     expires_after: Optional[str] = None
+    # Docker image configuration.
     docker_image: Optional[Any] = None
+    # An alias that can be used instead of the real fetch task name in
+    # fetch stanzas for tasks.
     fetch_alias: Optional[str] = None
+    # The prefix of the taskcluster artifact being uploaded.
+    # Defaults to `public/`; if it starts with something other than
+    # `public/` the artifact will require scopes to access.
     artifact_prefix: Optional[str] = None
+    # Task attributes.
     attributes: Optional[Dict[str, Any]] = None
 
     def __post_init__(self):
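
The fetch mapping is only loosely typed here (Dict[str, Any]): it must carry a "type" key, and the remaining keys depend on that type. A hypothetical task definition of the shape FetchSchema expects, using a static-url style fetch with placeholder URL and digest:

    fetch_task = {
        "name": "nodejs",
        "description": "Fetch a pinned Node.js tarball",
        "fetch": {
            "type": "static-url",
            "url": "https://example.com/node-v20.0.0-linux-x64.tar.xz",
            "sha256": "0" * 64,  # placeholder digest
            "size": 12345678,
        },
    }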

src/taskgraph/transforms/from_deps.py

Lines changed: 24 additions & 29 deletions
@@ -24,7 +24,6 @@
 from taskgraph.util.set_name import SET_NAME_MAP
 
 
-# Define FetchEntry for the fetches field
 class FetchEntry(Schema, rename=None):
     """A fetch entry for an artifact."""
 
@@ -33,40 +32,36 @@ class FetchEntry(Schema, rename=None):
 
 
 class FromDepsConfig(Schema):
-    """
-    Configuration for from-deps transforms.
-
-    Attributes:
-        kinds: Limit dependencies to specified kinds (defaults to all kinds in
-            `kind-dependencies`). The first kind in the list is the "primary" kind.
-            The dependency of this kind will be used to derive the label
-            and copy attributes (if `copy-attributes` is True).
-        set_name: UPDATE ME AND DOCS. Can be a string from SET_NAME_MAP, False, None,
-            or a dict with a SET_NAME_MAP key.
-        with_attributes: Limit dependencies to tasks whose attributes match
-            using :func:`~taskgraph.util.attributes.attrmatch`.
-        group_by: Group cross-kind dependencies using the given group-by
-            function. One task will be created for each group. If not
-            specified, the 'single' function will be used which creates
-            a new task for each individual dependency.
-        copy_attributes: If True, copy attributes from the dependency matching the
-            first kind in the `kinds` list (whether specified explicitly
-            or taken from `kind-dependencies`).
-        unique_kinds: If true (the default), there must be only a single unique task
-            for each kind in a dependency group. Setting this to false
-            disables that requirement.
-        fetches: If present, a `fetches` entry will be added for each task
-            dependency. Attributes of the upstream task may be used as
-            substitution values in the `artifact` or `dest` values of the
-            `fetches` entry.
-    """
-
+    # Optional fields
+    # Limit dependencies to specified kinds (defaults to all kinds in
+    # `kind-dependencies`).
+    #
+    # The first kind in the list is the "primary" kind. The
+    # dependency of this kind will be used to derive the label
+    # and copy attributes (if `copy-attributes` is True).
     kinds: Optional[List[str]] = None
+    # UPDATE ME AND DOCS
     set_name: Optional[Union[str, bool, Dict[str, Any]]] = None
+    # Limit dependencies to tasks whose attributes match
+    # using :func:`~taskgraph.util.attributes.attrmatch`.
     with_attributes: Optional[Dict[str, Union[List[Any], str]]] = None
+    # Group cross-kind dependencies using the given group-by
+    # function. One task will be created for each group. If not
+    # specified, the 'single' function will be used which creates
+    # a new task for each individual dependency.
     group_by: Optional[Union[str, Dict[str, Any]]] = None
+    # If True, copy attributes from the dependency matching the
+    # first kind in the `kinds` list (whether specified explicitly
+    # or taken from `kind-dependencies`).
     copy_attributes: Optional[bool] = None
+    # If true (the default), there must be only a single unique task
+    # for each kind in a dependency group. Setting this to false
+    # disables that requirement.
     unique_kinds: Optional[bool] = None
+    # If present, a `fetches` entry will be added for each task
+    # dependency. Attributes of the upstream task may be used as
+    # substitution values in the `artifact` or `dest` values of the
+    # `fetches` entry.
     fetches: Optional[Dict[str, List[Union[str, Dict[str, str]]]]] = None
 
     def __post_init__(self):
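
Every FromDepsConfig field is optional. A hypothetical from-deps stanza, shown as the parsed dict a kind definition would produce (kind and artifact names invented): one task per "build" dependency, copying its attributes and fetching a single artifact.

    from_deps = {
        "kinds": ["build"],
        "group_by": "single",
        "copy_attributes": True,
        "fetches": {
            "build": [
                {"artifact": "target.tar.gz", "dest": "build-artifacts"},
            ],
        },
    }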
