@@ -175,8 +175,8 @@ def extend_parameters_schema(schema, defaults_fn=None):
175175 # Store the extension schema for use during validation
176176 _schema_extensions .append (schema )
177177
178- # Schema extension is no longer supported with msgspec.Struct inheritance
179- # Extensions are tracked in _schema_extensions list instead
178+ # With msgspec, schema extensions are tracked in the _schema_extensions list
179+ # for validation purposes rather than being merged into a single schema
180180
181181 if defaults_fn :
182182 defaults_functions .append (defaults_fn )
@@ -240,92 +240,39 @@ def _fill_defaults(repo_root=None, **kwargs):
240240 return kwargs
241241
242242 def check (self ):
243- # For msgspec schemas, we need to validate differently
244- if isinstance (base_schema , type ) and issubclass (base_schema , msgspec .Struct ):
245- try :
246- # Convert underscore keys to kebab-case for msgspec validation
247- params = self .copy ()
248- # BaseSchema uses kebab-case (rename="kebab"), so we need to convert keys
249- kebab_params = {}
250- for k , v in params .items ():
251- # Convert underscore to kebab-case
252- kebab_key = k .replace ("_" , "-" )
253- kebab_params [kebab_key ] = v
254-
255- # Handle extensions if present
256- global _schema_extensions
257- for ext_schema in _schema_extensions :
258- if isinstance (ext_schema , dict ):
259- # Simple dict validation - just check if required keys exist
260- for key in ext_schema :
261- # Just skip validation of extensions for now
262- pass
263-
264- if self .strict :
265- # Strict validation with msgspec
266- # First check for extra fields
267- schema_fields = {
268- f .encode_name for f in msgspec .structs .fields (base_schema )
269- }
270-
271- # Add extension fields if present
272- for ext_schema in _schema_extensions :
273- if isinstance (ext_schema , dict ):
274- for key in ext_schema .keys ():
275- # Extract field name
276- if hasattr (key , "key" ):
277- field_name = key .key .replace ("_" , "-" )
278- else :
279- field_name = str (key ).replace ("_" , "-" )
280- schema_fields .add (field_name )
281-
282- extra_fields = set (kebab_params .keys ()) - schema_fields
283- if extra_fields :
284- raise ParameterMismatch (
285- f"Invalid parameters: Extra fields not allowed: { extra_fields } "
286- )
287- # Now validate the base schema fields
288- base_fields = {
289- f .encode_name for f in msgspec .structs .fields (base_schema )
290- }
291- base_params = {
292- k : v for k , v in kebab_params .items () if k in base_fields
293- }
294- msgspec .convert (base_params , base_schema )
295- else :
296- # Non-strict: validate only the fields that exist in the schema
297- # Filter to only schema fields
298- schema_fields = {
299- f .encode_name for f in msgspec .structs .fields (base_schema )
300- }
301- filtered_params = {
302- k : v for k , v in kebab_params .items () if k in schema_fields
303- }
304- msgspec .convert (filtered_params , base_schema )
305- except (msgspec .ValidationError , msgspec .DecodeError ) as e :
306- raise ParameterMismatch (f"Invalid parameters: { e } " )
307- else :
308- # For non-msgspec schemas, validate using the Schema class
309- from taskgraph .util .schema import validate_schema # noqa: PLC0415
310-
311- try :
312- if self .strict :
313- validate_schema (base_schema , self .copy (), "Invalid parameters:" )
314- else :
315- # In non-strict mode, allow extra fields
316- if hasattr (base_schema , "allow_extra" ):
317- original_allow_extra = base_schema .allow_extra
318- base_schema .allow_extra = True
319- try :
320- validate_schema (
321- base_schema , self .copy (), "Invalid parameters:"
322- )
323- finally :
324- base_schema .allow_extra = original_allow_extra
325- else :
326- validate_schema (base_schema , self .copy (), "Invalid parameters:" )
327- except Exception as e :
328- raise ParameterMismatch (str (e ))
243+ # Validate parameters using msgspec schema
244+ try :
245+ # Convert underscore keys to kebab-case since BaseSchema uses rename="kebab"
246+ kebab_params = {k .replace ("_" , "-" ): v for k , v in self .items ()}
247+
248+ if self .strict :
249+ # Strict mode: validate against schema and check for extra fields
250+ # Get all valid field names from the base schema
251+ schema_fields = {
252+ f .encode_name for f in msgspec .structs .fields (base_schema )
253+ }
254+
255+ # Check for extra fields
256+ extra_fields = set (kebab_params .keys ()) - schema_fields
257+ if extra_fields :
258+ raise ParameterMismatch (
259+ f"Invalid parameters: Extra fields not allowed: { extra_fields } "
260+ )
261+
262+ # Validate all parameters against the schema
263+ msgspec .convert (kebab_params , base_schema )
264+ else :
265+ # Non-strict mode: only validate fields that exist in the schema
266+ # Filter to only include fields defined in the schema
267+ schema_fields = {
268+ f .encode_name for f in msgspec .structs .fields (base_schema )
269+ }
270+ filtered_params = {
271+ k : v for k , v in kebab_params .items () if k in schema_fields
272+ }
273+ msgspec .convert (filtered_params , base_schema )
274+ except (msgspec .ValidationError , msgspec .DecodeError ) as e :
275+ raise ParameterMismatch (f"Invalid parameters: { e } " )
329276
330277 def __getitem__ (self , k ):
331278 try :
0 commit comments