Skip to content

Commit 1b37f95

Browse files
fix: updated based on review
1 parent 006a6e3 commit 1b37f95

File tree

14 files changed: 286 additions and 273 deletions.

docs/concepts/transforms.rst

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,7 @@ about the state of the tasks at given points. Here is an example:
110110
from taskgraph.util.schema import Schema
111111
112112
class MySchema(Schema):
113-
foo: str # Required field
113+
foo: str # Required field
114114
bar: Optional[bool] = None # Optional field
115115
116116
transforms = TransformSequence()

src/taskgraph/config.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
]
2424

2525

26-
class WorkerAlias(Schema):
26+
class WorkerAliasSchema(Schema):
2727
"""Worker alias configuration."""
2828

2929
provisioner: optionally_keyed_by("level", str) # type: ignore
@@ -32,10 +32,10 @@ class WorkerAlias(Schema):
3232
worker_type: optionally_keyed_by("level", str) # type: ignore
3333

3434

35-
class Workers(Schema, rename=None):
35+
class WorkersSchema(Schema, rename=None):
3636
"""Workers configuration."""
3737

38-
aliases: Dict[str, WorkerAlias]
38+
aliases: Dict[str, WorkerAliasSchema]
3939

4040

4141
class Repository(Schema, forbid_unknown_fields=False):
@@ -58,7 +58,7 @@ class RunConfig(Schema):
5858
use_caches: Optional[Union[bool, List[str]]] = None # Maps from "use-caches"
5959

6060

61-
class TaskGraphConfig(Schema):
61+
class TaskGraphSchema(Schema):
6262
"""Taskgraph specific configuration."""
6363

6464
# Required fields first
@@ -82,8 +82,8 @@ class GraphConfigSchema(Schema, forbid_unknown_fields=False):
8282
# Required fields first
8383
trust_domain: str # Maps from "trust-domain"
8484
task_priority: optionally_keyed_by("project", "level", TaskPriority) # type: ignore
85-
workers: Workers
86-
taskgraph: TaskGraphConfig
85+
workers: WorkersSchema
86+
taskgraph: TaskGraphSchema
8787

8888
# Optional fields
8989
docker_image_kind: Optional[str] = None # Maps from "docker-image-kind"
@@ -158,7 +158,6 @@ def kinds_dir(self):
158158

159159
def validate_graph_config(config):
160160
"""Validate graph configuration using msgspec."""
161-
# With rename="kebab", msgspec handles the conversion automatically
162161
validate_schema(GraphConfigSchema, config, "Invalid graph configuration:")
163162

164163

src/taskgraph/parameters.py

Lines changed: 54 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@ class ParameterMismatch(Exception):
2929
"""Raised when a parameters.yml has extra or missing parameters."""
3030

3131

32-
class CodeReviewConfig(Schema):
32+
class CodeReviewSchema(Schema):
3333
"""Code review configuration."""
3434

3535
# Required field
@@ -73,10 +73,10 @@ class BaseSchema(Schema):
7373
tasks_for: str
7474

7575
# Optional fields
76-
next_version: Optional[str]
77-
optimize_strategies: Optional[str]
78-
version: Optional[str]
79-
code_review: Optional[CodeReviewConfig] = None
76+
next_version: Optional[str] = None
77+
optimize_strategies: Optional[str] = None
78+
version: Optional[str] = None
79+
code_review: Optional[CodeReviewSchema] = None
8080

8181

8282
def get_contents(path):
@@ -243,30 +243,74 @@ def check(self):
243243

244244
if self.strict:
245245
# Strict mode: validate against schema and check for extra fields
246-
# Get all valid field names from the base schema
246+
# Get all valid field names from the base schema and extensions
247247
schema_fields = {
248248
f.encode_name for f in msgspec.structs.fields(BaseSchema)
249249
}
250250

251+
# Add fields from extension schemas
252+
for ext_schema in _schema_extensions:
253+
if isinstance(ext_schema, type) and issubclass(
254+
ext_schema, msgspec.Struct
255+
):
256+
schema_fields.update(
257+
{f.encode_name for f in msgspec.structs.fields(ext_schema)}
258+
)
259+
251260
# Check for extra fields
252261
extra_fields = set(kebab_params.keys()) - schema_fields
253262
if extra_fields:
254263
raise ParameterMismatch(
255264
f"Invalid parameters: Extra fields not allowed: {extra_fields}"
256265
)
257266

258-
# Validate all parameters against the schema
259-
msgspec.convert(kebab_params, BaseSchema)
267+
# Validate base schema fields only (filter out extension fields)
268+
base_fields = {
269+
f.encode_name for f in msgspec.structs.fields(BaseSchema)
270+
}
271+
base_params = {
272+
k: v for k, v in kebab_params.items() if k in base_fields
273+
}
274+
msgspec.convert(base_params, BaseSchema)
275+
276+
# Also validate against extension schemas
277+
for ext_schema in _schema_extensions:
278+
if isinstance(ext_schema, type) and issubclass(
279+
ext_schema, msgspec.Struct
280+
):
281+
# Only validate fields that belong to this extension
282+
ext_fields = {
283+
f.encode_name for f in msgspec.structs.fields(ext_schema)
284+
}
285+
ext_params = {
286+
k: v for k, v in kebab_params.items() if k in ext_fields
287+
}
288+
if ext_params:
289+
msgspec.convert(ext_params, ext_schema)
260290
else:
261-
# Non-strict mode: only validate fields that exist in the schema
262-
# Filter to only include fields defined in the schema
291+
# Non-strict mode: only validate fields that exist in the schemas
292+
# Filter to only include fields defined in the base schema
263293
schema_fields = {
264294
f.encode_name for f in msgspec.structs.fields(BaseSchema)
265295
}
266296
filtered_params = {
267297
k: v for k, v in kebab_params.items() if k in schema_fields
268298
}
269299
msgspec.convert(filtered_params, BaseSchema)
300+
301+
# Also validate extension schemas in non-strict mode
302+
for ext_schema in _schema_extensions:
303+
if isinstance(ext_schema, type) and issubclass(
304+
ext_schema, msgspec.Struct
305+
):
306+
ext_fields = {
307+
f.encode_name for f in msgspec.structs.fields(ext_schema)
308+
}
309+
ext_params = {
310+
k: v for k, v in kebab_params.items() if k in ext_fields
311+
}
312+
if ext_params:
313+
msgspec.convert(ext_params, ext_schema)
270314
except (msgspec.ValidationError, msgspec.DecodeError) as e:
271315
raise ParameterMismatch(f"Invalid parameters: {e}")
272316

src/taskgraph/transforms/chunking.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from taskgraph.util.templates import substitute
1010

1111

12-
class ChunkConfig(Schema):
12+
class ChunkSchema(Schema):
1313
"""
1414
`chunk` can be used to split one task into `total-chunks`
1515
tasks, substituting `this_chunk` and `total_chunks` into any
@@ -24,13 +24,13 @@ class ChunkConfig(Schema):
2424

2525

2626
#: Schema for chunking transforms
27-
class ChunkSchema(Schema, forbid_unknown_fields=False):
27+
class ChunksSchema(Schema, forbid_unknown_fields=False):
2828
# Optional, so it can be used for a subset of tasks in a kind
29-
chunk: Optional[ChunkConfig] = None
29+
chunk: Optional[ChunkSchema] = None
3030

3131

3232
transforms = TransformSequence()
33-
transforms.add_validate(ChunkSchema)
33+
transforms.add_validate(ChunksSchema)
3434

3535

3636
@transforms.add

src/taskgraph/transforms/fetch.py

Lines changed: 8 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
import os
1010
import re
1111
from dataclasses import dataclass
12-
from typing import Any, Callable, Dict, Optional
12+
from typing import Any, Callable, Dict, Optional, Union
1313

1414
import msgspec
1515

@@ -24,14 +24,6 @@
2424
CACHE_TYPE = "content.v1"
2525

2626

27-
#: Schema for fetch transforms
28-
class FetchConfig(Schema, rename=None, omit_defaults=False):
29-
"""Configuration for a fetch task type."""
30-
31-
type: str
32-
# Additional fields handled dynamically by fetch builders
33-
34-
3527
class FetchSchema(Schema):
3628
# Required fields
3729
# Name of the task.
@@ -70,7 +62,7 @@ def __post_init__(self):
7062

7163
@dataclass(frozen=True)
7264
class FetchBuilder:
73-
schema: Any # Either msgspec.Struct type or validation function
65+
schema: Union[Schema, Callable]
7466
builder: Callable
7567

7668

@@ -194,7 +186,7 @@ def make_task(config, tasks):
194186
yield task_desc
195187

196188

197-
class GPGSignatureConfig(Schema):
189+
class GPGSignatureSchema(Schema):
198190
"""GPG signature verification configuration."""
199191

200192
# URL where GPG signature document can be obtained. Can contain the
@@ -204,7 +196,7 @@ class GPGSignatureConfig(Schema):
204196
key_path: str
205197

206198

207-
class StaticUrlFetchConfig(Schema, rename="kebab"):
199+
class StaticUrlFetchSchema(Schema):
208200
"""Configuration for static-url fetch type."""
209201

210202
type: str
@@ -215,7 +207,7 @@ class StaticUrlFetchConfig(Schema, rename="kebab"):
215207
# Size of the downloaded entity, in bytes.
216208
size: int
217209
# GPG signature verification.
218-
gpg_signature: Optional[GPGSignatureConfig] = None
210+
gpg_signature: Optional[GPGSignatureSchema] = None
219211
# The name to give to the generated artifact. Defaults to the file
220212
# portion of the URL. Using a different extension converts the
221213
# archive to the given type. Only conversion to .tar.zst is supported.
@@ -233,7 +225,7 @@ class StaticUrlFetchConfig(Schema, rename="kebab"):
233225
# it is important to update the digest data used to compute cache hits.
234226

235227

236-
@fetch_builder("static-url", StaticUrlFetchConfig)
228+
@fetch_builder("static-url", StaticUrlFetchSchema)
237229
def create_fetch_url_task(config, name, fetch):
238230
artifact_name = fetch.get("artifact-name")
239231
if not artifact_name:
@@ -296,7 +288,7 @@ def create_fetch_url_task(config, name, fetch):
296288
}
297289

298290

299-
class GitFetchConfig(Schema):
291+
class GitFetchSchema(Schema):
300292
"""Configuration for git fetch type."""
301293

302294
type: str
@@ -312,7 +304,7 @@ class GitFetchConfig(Schema):
312304
ssh_key: Optional[str] = None
313305

314306

315-
@fetch_builder("git", GitFetchConfig)
307+
@fetch_builder("git", GitFetchSchema)
316308
def create_git_fetch_task(config, name, fetch):
317309
path_prefix = fetch.get("path-prefix")
318310
if not path_prefix:

src/taskgraph/transforms/from_deps.py

Lines changed: 11 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -13,25 +13,19 @@
1313

1414
from copy import deepcopy
1515
from textwrap import dedent
16-
from typing import Any, Dict, List, Optional, Union
17-
18-
import msgspec
16+
from typing import Any, Dict, List, Literal, Optional, Union
1917

2018
from taskgraph.transforms.base import TransformSequence
2119
from taskgraph.util.attributes import attrmatch
2220
from taskgraph.util.dependencies import GROUP_BY_MAP, get_dependencies
2321
from taskgraph.util.schema import Schema, validate_schema
2422
from taskgraph.util.set_name import SET_NAME_MAP
2523

24+
SetNameType = Literal["strip-kind", "retain-kind"]
25+
GroupByType = Literal["single", "all", "attribute"]
2626

27-
class FetchEntry(Schema, rename=None):
28-
"""A fetch entry for an artifact."""
29-
30-
artifact: str
31-
dest: Optional[str] = None
3227

33-
34-
class FromDepsConfig(Schema):
28+
class FromDepsChildSchema(Schema):
3529
# Optional fields
3630
# Limit dependencies to specified kinds (defaults to all kinds in
3731
# `kind-dependencies`).
@@ -40,16 +34,17 @@ class FromDepsConfig(Schema):
4034
# dependency of this kind will be used to derive the label
4135
# and copy attributes (if `copy-attributes` is True).
4236
kinds: Optional[List[str]] = None
43-
# UPDATE ME AND DOCS
44-
set_name: Optional[Union[str, bool, Dict[str, Any]]] = None
37+
# Set the task name using the specified function. Can be False to
38+
# disable name setting, or a string/dict specifying the function to use.
39+
set_name: Optional[Union[SetNameType, bool, Dict[SetNameType, Any]]] = None
4540
# Limit dependencies to tasks whose attributes match
4641
# using :func:`~taskgraph.util.attributes.attrmatch`.
4742
with_attributes: Optional[Dict[str, Union[List[Any], str]]] = None
4843
# Group cross-kind dependencies using the given group-by
4944
# function. One task will be created for each group. If not
5045
# specified, the 'single' function will be used which creates
5146
# a new task for each individual dependency.
52-
group_by: Optional[Union[str, Dict[str, Any]]] = None
47+
group_by: Optional[Union[GroupByType, Dict[GroupByType, Any]]] = None
5348
# If True, copy attributes from the dependency matching the
5449
# first kind in the `kinds` list (whether specified explicitly
5550
# or taken from `kind-dependencies`).
@@ -64,35 +59,12 @@ class FromDepsConfig(Schema):
6459
# `fetches` entry.
6560
fetches: Optional[Dict[str, List[Union[str, Dict[str, str]]]]] = None
6661

67-
def __post_init__(self):
68-
# Validate set_name
69-
if self.set_name is not None and self.set_name is not False:
70-
if isinstance(self.set_name, str) and self.set_name not in SET_NAME_MAP:
71-
raise msgspec.ValidationError(f"Invalid set-name: {self.set_name}")
72-
elif isinstance(self.set_name, dict):
73-
keys = list(self.set_name.keys())
74-
if len(keys) != 1 or keys[0] not in SET_NAME_MAP:
75-
raise msgspec.ValidationError(
76-
f"Invalid set-name dict: {self.set_name}"
77-
)
78-
79-
# Validate group_by
80-
if self.group_by is not None:
81-
if isinstance(self.group_by, str) and self.group_by not in GROUP_BY_MAP:
82-
raise msgspec.ValidationError(f"Invalid group-by: {self.group_by}")
83-
elif isinstance(self.group_by, dict):
84-
keys = list(self.group_by.keys())
85-
if len(keys) != 1 or keys[0] not in GROUP_BY_MAP:
86-
raise msgspec.ValidationError(
87-
f"Invalid group-by dict: {self.group_by}"
88-
)
89-
90-
91-
#: Schema for from_deps transforms
62+
63+
# Schema for from_deps transforms
9264
class FromDepsSchema(Schema, forbid_unknown_fields=False):
9365
"""Schema for from_deps transforms."""
9466

95-
from_deps: FromDepsConfig
67+
from_deps: FromDepsChildSchema
9668

9769

9870
transforms = TransformSequence()

src/taskgraph/transforms/matrix.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from taskgraph.util.templates import substitute_task_fields
1616

1717

18-
class MatrixConfig(Schema, forbid_unknown_fields=False):
18+
class MatrixChildSchema(Schema, forbid_unknown_fields=False):
1919
"""
2020
Matrix configuration for generating multiple tasks.
2121
"""
@@ -47,7 +47,7 @@ class MatrixSchema(Schema, forbid_unknown_fields=False):
4747
"""
4848

4949
name: str
50-
matrix: Optional[MatrixConfig] = None
50+
matrix: Optional[MatrixChildSchema] = None
5151

5252

5353
transforms = TransformSequence()

0 commit comments

Comments
 (0)