Skip to content

Commit e24876f

Browse files
Add validation for schema YAML files (#2528) (#4049)
1 parent 4b66642 commit e24876f

File tree

4 files changed

+871
-2
lines changed

4 files changed

+871
-2
lines changed

src/helm/benchmark/presentation/schema.py

Lines changed: 371 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import dataclasses
33
from dataclasses import dataclass, field
44
import json
5-
from typing import List, Optional, Dict
5+
from typing import List, Optional, Dict, Set, Tuple, FrozenSet
66
import dacite
77
from inspect import cleandoc
88
import mako.template
99
import yaml
10+
import re
11+
from enum import Enum
1012
from importlib import resources
1113

1214
from helm.benchmark.presentation.taxonomy_info import TaxonomyInfo
@@ -27,6 +29,9 @@
2729
_ADAPTER_SPEC_FILENAME = "adapter_spec.py"
_ADAPTER_SPEC_CLASS_NAME = "AdapterSpec"

# Split names a metric matcher may reference literally; anything else must be
# a template variable (see TEMPLATE_VARIABLE_PATTERN below).
VALID_SPLITS: Set[str] = {"test", "valid", "__all__"}
# Matches a string that is exactly one ${identifier} placeholder, e.g. "${main_split}".
TEMPLATE_VARIABLE_PATTERN = re.compile(r"^\$\{[A-Za-z_][A-Za-z0-9_]*\}$")
34+
3035

3136
@dataclass(frozen=True)
3237
class Field:
@@ -273,3 +278,368 @@ def read_schema(schema_path: str) -> Schema:
273278
if schema.adapter:
274279
hwarn(f"The `adapter` field is deprecated and should be removed from schema file {schema_path}")
275280
return dataclasses.replace(schema, adapter=get_adapter_fields())
281+
282+
283+
class ValidationSeverity(str, Enum):
    """Severity level for validation messages.

    Subclasses ``str`` so members compare equal to (and format as) their
    plain string values.
    """

    ERROR = "error"  # blocks strict validation (see validate_schema)
    WARNING = "warning"  # advisory only; never raises
288+
289+
290+
@dataclass(frozen=True)
class SchemaValidationMessage:
    """A single issue (error or warning) found while validating a schema."""

    # How serious the issue is.
    severity: ValidationSeverity
    # Human-readable description of the problem.
    message: str
    # Path of the schema file the issue came from, when known.
    schema_path: Optional[str] = None
    # Bracketed path to the offending entry within the schema, when known.
    location: Optional[str] = None

    def __str__(self) -> str:
        # Assemble the optional and mandatory pieces in display order,
        # dropping the pieces whose source field is unset.
        pieces = [
            f"[{self.schema_path}]" if self.schema_path else None,
            f"[{self.severity.value.upper()}]",
            f"at {self.location}:" if self.location else None,
            self.message,
        ]
        return " ".join(piece for piece in pieces if piece is not None)
308+
309+
310+
class SchemaValidationError(ValueError):
    """Exception raised when schema validation fails with errors."""

    def __init__(self, messages: List[SchemaValidationMessage]):
        # Keep every message (warnings included) for programmatic access;
        # only the errors go into the exception text.
        self.messages = messages
        errors = [m for m in messages if m.severity == ValidationSeverity.ERROR]
        summary = "\n".join(str(m) for m in errors)
        super().__init__(f"Schema validation failed with {len(errors)} error(s):\n{summary}")
320+
321+
322+
def is_template_variable(value: str) -> bool:
    """Return True iff *value* is a template variable of the form ``${var_name}``."""
    # Empty/None short-circuits; otherwise the anchored pattern must match.
    return bool(value) and TEMPLATE_VARIABLE_PATTERN.match(value) is not None
327+
328+
329+
def _detect_cycles_in_subgroups(
    run_groups: List[RunGroup],
    name_to_run_group: Dict[str, RunGroup],
) -> List[Tuple[str, List[str]]]:
    """Detect all cycles in the subgroup graph using 3-color DFS.

    Nodes are run-group names; there is an edge from a group to each entry in
    its ``subgroups`` list. Returns one ``(start_node, cycle_path)`` tuple per
    distinct cycle, where ``cycle_path`` begins and ends with ``start_node``.
    Cycles over the same *set* of nodes are reported only once.
    """
    # Classic white/gray/black DFS coloring.
    UNVISITED, VISITING, VISITED = 0, 1, 2

    # Groups with blank names cannot be graph nodes (they are flagged
    # separately by validate_schema).
    valid_run_groups = [rg for rg in run_groups if rg.name and rg.name.strip()]
    visit_status: Dict[str, int] = {rg.name: UNVISITED for rg in valid_run_groups}

    cycles: List[Tuple[str, List[str]]] = []
    # Node sets of cycles already reported, so the same loop reached from a
    # different entry point is not duplicated.
    reported_cycle_nodes: Set[FrozenSet[str]] = set()

    def dfs(node_name: str, path: List[str]) -> None:
        # Blank or unknown names are skipped here; dangling references are
        # reported separately by validate_schema's subgroup checks.
        if not node_name or node_name not in visit_status:
            return

        if visit_status[node_name] == VISITING:
            # Back edge to a node on the current DFS path => cycle found.
            if node_name in path:
                cycle_start_idx = path.index(node_name)
                # Closed walk: from the first occurrence back to the node itself.
                cycle_path = path[cycle_start_idx:] + [node_name]
                cycle_nodes = frozenset(cycle_path[:-1])

                if cycle_nodes not in reported_cycle_nodes:
                    reported_cycle_nodes.add(cycle_nodes)
                    cycles.append((node_name, cycle_path))
            return

        if visit_status[node_name] == VISITED:
            # Already fully explored from another root; nothing new to find.
            return

        visit_status[node_name] = VISITING
        path.append(node_name)

        if node_name in name_to_run_group:
            run_group = name_to_run_group[node_name]
            subgroups = run_group.subgroups or []
            for subgroup_name in subgroups:
                dfs(subgroup_name, path)

        # Backtrack: node leaves the active path and becomes fully explored.
        path.pop()
        visit_status[node_name] = VISITED

    # Run DFS from every node so disconnected components are also covered.
    for run_group in valid_run_groups:
        if visit_status.get(run_group.name) == UNVISITED:
            dfs(run_group.name, [])

    return cycles
377+
378+
379+
def validate_schema(
    schema: Schema,
    *,
    schema_path: Optional[str] = None,
    strict: bool = True,
    check_parent_child_partition: bool = False,
    check_orphan_children: bool = False,
) -> List[SchemaValidationMessage]:
    """Validate a Schema object and return a list of validation messages.

    Checks performed (in order):
    - run group / metric group / metric names are non-empty,
    - names of all four entity kinds are unique,
    - cross-references (subgroups, metric_groups, hidden metric groups,
      metric names, splits, perturbation names) resolve to defined entries
      or template variables,
    - the subgroup graph is acyclic,
    - optional structural warnings (see keyword flags).

    Args:
        schema: The schema to validate.
        schema_path: Path of the schema file; used only to annotate messages.
        strict: If True, raises SchemaValidationError when validation errors
            are found.
        check_parent_child_partition: Warn when a run group has both
            subgroups and metric_groups.
        check_orphan_children: Warn when a leaf run group (metric_groups but
            no subgroups) is never referenced as a subgroup.

    Returns:
        All validation messages produced, errors and warnings alike.

    Raises:
        SchemaValidationError: If strict=True and at least one error was found.
    """
    messages: List[SchemaValidationMessage] = []

    def add_error(message: str, location: Optional[str] = None) -> None:
        messages.append(
            SchemaValidationMessage(
                severity=ValidationSeverity.ERROR,
                message=message,
                schema_path=schema_path,
                location=location,
            )
        )

    def add_warning(message: str, location: Optional[str] = None) -> None:
        messages.append(
            SchemaValidationMessage(
                severity=ValidationSeverity.WARNING,
                message=message,
                schema_path=schema_path,
                location=location,
            )
        )

    def check_empty_names(items, kind_title: str, plural: str) -> None:
        # Flag entries whose name is missing or whitespace-only.
        for i, item in enumerate(items):
            if not item.name or not item.name.strip():
                add_error(f"{kind_title} at index {i} has empty or whitespace-only name", location=f"{plural}[{i}]")

    def check_duplicates(items, kind: str, plural: str) -> None:
        # Flag entries whose name was already used earlier in the list.
        seen: Set[str] = set()
        for item in items:
            if item.name and item.name in seen:
                add_error(f"Duplicate {kind} name: '{item.name}'", location=f"{plural}[{item.name}]")
            if item.name:
                seen.add(item.name)

    run_groups = schema.run_groups or []
    metric_groups = schema.metric_groups or []
    metrics = schema.metrics or []
    perturbations = schema.perturbations or []

    defined_run_group_names: Set[str] = set(schema.name_to_run_group.keys())
    defined_metric_group_names: Set[str] = set(schema.name_to_metric_group.keys())
    defined_metric_names: Set[str] = set(schema.name_to_metric.keys())
    defined_perturbation_names: Set[str] = set(schema.name_to_perturbation.keys())

    # Check for empty names (perturbations intentionally have no such check).
    check_empty_names(run_groups, "Run group", "run_groups")
    check_empty_names(metric_groups, "Metric group", "metric_groups")
    check_empty_names(metrics, "Metric", "metrics")

    # Check for duplicate names.
    check_duplicates(run_groups, "run_group", "run_groups")
    check_duplicates(metric_groups, "metric_group", "metric_groups")
    check_duplicates(metrics, "metric", "metrics")
    check_duplicates(perturbations, "perturbation", "perturbations")

    # Validate run_group.subgroups references.
    for run_group in run_groups:
        if not run_group.name:
            continue

        for i, subgroup_name in enumerate(run_group.subgroups or []):
            if not subgroup_name:
                add_error(
                    f"Empty subgroup reference at index {i}", location=f"run_groups[{run_group.name}].subgroups[{i}]"
                )
            elif subgroup_name not in defined_run_group_names:
                add_error(
                    f"Subgroup '{subgroup_name}' is not defined as a run_group",
                    location=f"run_groups[{run_group.name}].subgroups[{i}]",
                )

    # Validate run_group.metric_groups and subgroup_metric_groups_hidden references.
    for run_group in run_groups:
        if not run_group.name:
            continue

        for i, metric_group_name in enumerate(run_group.metric_groups or []):
            if not metric_group_name:
                add_error(
                    f"Empty metric_group reference at index {i}",
                    location=f"run_groups[{run_group.name}].metric_groups[{i}]",
                )
            elif metric_group_name not in defined_metric_group_names:
                add_error(
                    f"Metric group '{metric_group_name}' is not defined",
                    location=f"run_groups[{run_group.name}].metric_groups[{i}]",
                )

        for i, hidden_mg_name in enumerate(run_group.subgroup_metric_groups_hidden or []):
            if hidden_mg_name and hidden_mg_name not in defined_metric_group_names:
                add_error(
                    f"Hidden metric group '{hidden_mg_name}' is not defined",
                    location=f"run_groups[{run_group.name}].subgroup_metric_groups_hidden[{i}]",
                )

    # Validate each metric_group's metric matchers (name, split, perturbation).
    for metric_group in metric_groups:
        if not metric_group.name:
            continue

        for i, metric_matcher in enumerate(metric_group.metrics or []):
            location = f"metric_groups[{metric_group.name}].metrics[{i}]"

            metric_name = metric_matcher.name
            if not metric_name:
                add_error("Empty metric name", location=f"{location}.name")
            elif not is_template_variable(metric_name) and metric_name not in defined_metric_names:
                add_error(
                    f"Metric '{metric_name}' is not defined and is not a template variable", location=f"{location}.name"
                )

            split = metric_matcher.split
            if not split:
                add_error("Empty split value", location=f"{location}.split")
            elif not is_template_variable(split) and split not in VALID_SPLITS:
                add_error(
                    f"Split '{split}' is not valid. Must be one of {sorted(VALID_SPLITS)} "
                    f"or a template variable like ${{main_split}}",
                    location=f"{location}.split",
                )

            # perturbation_name is optional: validate only when present and non-empty.
            pert_name = metric_matcher.perturbation_name
            if pert_name:
                if not is_template_variable(pert_name) and pert_name not in defined_perturbation_names:
                    add_error(f"Perturbation '{pert_name}' is not defined", location=f"{location}.perturbation_name")

    # Detect circular references in the subgroup graph.
    for cycle_start, cycle_path in _detect_cycles_in_subgroups(run_groups, schema.name_to_run_group):
        cycle_str = " -> ".join(cycle_path)
        add_error(
            f"Circular reference detected in subgroups: {cycle_str}", location=f"run_groups[{cycle_start}].subgroups"
        )

    # Optional: a run group should be a parent (subgroups) or a child
    # (metric_groups), not both.
    if check_parent_child_partition:
        for run_group in run_groups:
            if not run_group.name:
                continue

            if (run_group.subgroups or []) and (run_group.metric_groups or []):
                add_warning(
                    f"Run group '{run_group.name}' has both subgroups and metric_groups",
                    location=f"run_groups[{run_group.name}]",
                )

    # Optional: every child (leaf) run group should be referenced by some parent.
    if check_orphan_children:
        referenced_as_subgroup: Set[str] = {
            subgroup_name for run_group in run_groups for subgroup_name in (run_group.subgroups or [])
        }

        for run_group in run_groups:
            if not run_group.name:
                continue

            is_child = bool(run_group.metric_groups or []) and not (run_group.subgroups or [])
            if is_child and run_group.name not in referenced_as_subgroup:
                add_warning(
                    f"Child run_group '{run_group.name}' is not referenced by any parent",
                    location=f"run_groups[{run_group.name}]",
                )

    if strict and any(msg.severity == ValidationSeverity.ERROR for msg in messages):
        # Raise with ALL messages so callers can still inspect the warnings.
        raise SchemaValidationError(messages)

    return messages
594+
595+
596+
def validate_schema_file(
    schema_path: str,
    *,
    strict: bool = True,
    **kwargs,
) -> List[SchemaValidationMessage]:
    """Convenience function to validate a schema file directly.

    Args:
        schema_path: Path to the schema YAML file.
        strict: If True, raise SchemaValidationError on any error, including
            failure to load the file at all.
        **kwargs: Forwarded to validate_schema (e.g. check_orphan_children).

    Returns:
        All validation messages produced for the file.

    Raises:
        SchemaValidationError: If strict=True and an error was found.
    """

    def load_failure(message: str) -> List[SchemaValidationMessage]:
        # Every load failure is handled identically: one ERROR message,
        # raised in strict mode and returned otherwise.
        msg = SchemaValidationMessage(
            severity=ValidationSeverity.ERROR,
            message=message,
            schema_path=schema_path,
        )
        if strict:
            raise SchemaValidationError([msg])
        return [msg]

    try:
        schema = read_schema(schema_path)
    except FileNotFoundError:
        return load_failure(f"Schema file not found: {schema_path}")
    except yaml.YAMLError as e:
        return load_failure(f"Invalid YAML syntax: {e}")
    except Exception as e:
        # Any other load problem (e.g. dacite/type errors) is surfaced as a
        # validation error rather than propagated raw.
        return load_failure(f"Failed to load schema: {type(e).__name__}: {e}")

    return validate_schema(schema, schema_path=schema_path, strict=strict, **kwargs)
634+
635+
636+
def get_all_schema_paths() -> List[str]:
    """Get paths to all schema YAML files included in the HELM package.

    A schema file is any package resource named ``schema_*.yaml``; the
    returned paths are sorted alphabetically.
    """
    package_root = resources.files(SCHEMA_YAML_PACKAGE)
    return sorted(
        str(package_root.joinpath(entry.name))
        for entry in package_root.iterdir()
        if entry.name.startswith("schema_") and entry.name.endswith(".yaml")
    )

0 commit comments

Comments
 (0)