|
2 | 2 | import dataclasses |
3 | 3 | from dataclasses import dataclass, field |
4 | 4 | import json |
5 | | -from typing import List, Optional, Dict |
| 5 | +from typing import List, Optional, Dict, Set, Tuple, FrozenSet |
6 | 6 | import dacite |
7 | 7 | from inspect import cleandoc |
8 | 8 | import mako.template |
9 | 9 | import yaml |
| 10 | +import re |
| 11 | +from enum import Enum |
10 | 12 | from importlib import resources |
11 | 13 |
|
12 | 14 | from helm.benchmark.presentation.taxonomy_info import TaxonomyInfo |
|
27 | 29 | _ADAPTER_SPEC_FILENAME = "adapter_spec.py" |
28 | 30 | _ADAPTER_SPEC_CLASS_NAME = "AdapterSpec" |
29 | 31 |
|
# Recognized literal values for a metric matcher's `split` field; template
# variables such as ${main_split} are also accepted by the validator.
VALID_SPLITS: Set[str] = {"test", "valid", "__all__"}

# Matches a whole-string template variable of the form ${identifier}.
TEMPLATE_VARIABLE_PATTERN = re.compile(r"^\$\{[A-Za-z_][A-Za-z0-9_]*\}$")

30 | 35 |
|
31 | 36 | @dataclass(frozen=True) |
32 | 37 | class Field: |
@@ -273,3 +278,368 @@ def read_schema(schema_path: str) -> Schema: |
273 | 278 | if schema.adapter: |
274 | 279 | hwarn(f"The `adapter` field is deprecated and should be removed from schema file {schema_path}") |
275 | 280 | return dataclasses.replace(schema, adapter=get_adapter_fields()) |
| 281 | + |
| 282 | + |
class ValidationSeverity(str, Enum):
    """How serious a schema validation finding is.

    Mixes in ``str`` so members compare equal to their plain string values
    (e.g. ``ValidationSeverity.ERROR == "error"``) and serialize cleanly.
    """

    ERROR = "error"
    WARNING = "warning"
| 288 | + |
| 289 | + |
@dataclass(frozen=True)
class SchemaValidationMessage:
    """A single issue discovered while validating a schema."""

    # Whether this finding is an error or a warning.
    severity: ValidationSeverity
    # Human-readable description of the problem.
    message: str
    # Path of the schema file the issue was found in, if known.
    schema_path: Optional[str] = None
    # Location within the schema (e.g. "run_groups[foo].subgroups[0]"), if known.
    location: Optional[str] = None

    def __str__(self) -> str:
        # Optional fragments collapse to "" so the separators stay single spaces.
        path_part = f"[{self.schema_path}] " if self.schema_path else ""
        location_part = f"at {self.location}: " if self.location else ""
        return f"{path_part}[{self.severity.value.upper()}] {location_part}{self.message}"
| 308 | + |
| 309 | + |
class SchemaValidationError(ValueError):
    """Raised when schema validation finds at least one error-level issue.

    The full list of messages (errors and warnings) is kept on ``messages``;
    the exception text summarizes only the error-level ones.
    """

    def __init__(self, messages: List[SchemaValidationMessage]):
        self.messages = messages
        errors = [m for m in messages if m.severity == ValidationSeverity.ERROR]
        details = "\n".join(str(m) for m in errors)
        super().__init__(f"Schema validation failed with {len(errors)} error(s):\n{details}")
| 320 | + |
| 321 | + |
def is_template_variable(value: str) -> bool:
    """Return True iff *value* is exactly one template variable like ``${var_name}``."""
    # Empty/None short-circuits; the pattern is anchored, so partial matches fail.
    return bool(value) and TEMPLATE_VARIABLE_PATTERN.match(value) is not None
| 327 | + |
| 328 | + |
def _detect_cycles_in_subgroups(
    run_groups: List[RunGroup],
    name_to_run_group: Dict[str, RunGroup],
) -> List[Tuple[str, List[str]]]:
    """Find every cycle in the run-group subgroup graph.

    Performs a white/gray/black depth-first search. Returns a list of
    (entry_node, cycle_path) pairs where cycle_path starts and ends at the
    entry node; each distinct set of cycle nodes is reported only once.
    """
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited / on current DFS path / finished

    # Only consider groups with a non-empty, non-whitespace name.
    color: Dict[str, int] = {rg.name: WHITE for rg in run_groups if rg.name and rg.name.strip()}

    found: List[Tuple[str, List[str]]] = []
    seen_cycle_node_sets: Set[FrozenSet[str]] = set()

    def visit(name: str, stack: List[str]) -> None:
        # Empty or unknown names are reported by reference validation, not here.
        if not name or name not in color:
            return

        state = color[name]
        if state == GRAY:
            # Back edge: the cycle is the stack suffix starting at `name`.
            if name in stack:
                start = stack.index(name)
                loop = stack[start:] + [name]
                nodes = frozenset(loop[:-1])
                if nodes not in seen_cycle_node_sets:
                    seen_cycle_node_sets.add(nodes)
                    found.append((name, loop))
            return
        if state == BLACK:
            return

        color[name] = GRAY
        stack.append(name)
        group = name_to_run_group.get(name)
        if group is not None:
            for child in group.subgroups or []:
                visit(child, stack)
        stack.pop()
        color[name] = BLACK

    for name in list(color):
        if color[name] == WHITE:
            visit(name, [])

    return found
| 377 | + |
| 378 | + |
def validate_schema(
    schema: Schema,
    *,
    schema_path: Optional[str] = None,
    strict: bool = True,
    check_parent_child_partition: bool = False,
    check_orphan_children: bool = False,
) -> List[SchemaValidationMessage]:
    """Validate a Schema object and return a list of validation messages.

    Error-level checks:
      - empty or whitespace-only names (run groups, metric groups, metrics)
      - duplicate names (run groups, metric groups, metrics, perturbations)
      - dangling references: subgroups, metric_groups, hidden metric groups,
        metric names, splits, and perturbation names in metric matchers
      - circular references in the subgroup graph

    Warning-level checks (opt-in):
      - check_parent_child_partition: flags run groups that have both
        subgroups and metric_groups
      - check_orphan_children: flags leaf run groups (metric_groups but no
        subgroups) that no parent references

    If strict=True, raises SchemaValidationError when validation errors are found.
    """
    messages: List[SchemaValidationMessage] = []

    def add_message(severity: ValidationSeverity, message: str, location: Optional[str]) -> None:
        messages.append(
            SchemaValidationMessage(
                severity=severity,
                message=message,
                schema_path=schema_path,
                location=location,
            )
        )

    def add_error(message: str, location: Optional[str] = None) -> None:
        add_message(ValidationSeverity.ERROR, message, location)

    def add_warning(message: str, location: Optional[str] = None) -> None:
        add_message(ValidationSeverity.WARNING, message, location)

    run_groups = schema.run_groups or []
    metric_groups = schema.metric_groups or []
    metrics = schema.metrics or []
    perturbations = schema.perturbations or []

    defined_run_group_names: Set[str] = set(schema.name_to_run_group.keys())
    defined_metric_group_names: Set[str] = set(schema.name_to_metric_group.keys())
    defined_metric_names: Set[str] = set(schema.name_to_metric.keys())
    defined_perturbation_names: Set[str] = set(schema.name_to_perturbation.keys())

    # Check for empty names.
    for i, run_group in enumerate(run_groups):
        if not run_group.name or not run_group.name.strip():
            add_error(f"Run group at index {i} has empty or whitespace-only name", location=f"run_groups[{i}]")

    for i, metric_group in enumerate(metric_groups):
        if not metric_group.name or not metric_group.name.strip():
            add_error(f"Metric group at index {i} has empty or whitespace-only name", location=f"metric_groups[{i}]")

    for i, metric in enumerate(metrics):
        if not metric.name or not metric.name.strip():
            add_error(f"Metric at index {i} has empty or whitespace-only name", location=f"metrics[{i}]")

    # Check for duplicate names. All four entity kinds produce the same
    # message/location shape, so one helper replaces four near-identical loops.
    def check_duplicates(items, kind: str, collection: str) -> None:
        seen: Set[str] = set()
        for item in items:
            if item.name and item.name in seen:
                add_error(f"Duplicate {kind} name: '{item.name}'", location=f"{collection}[{item.name}]")
            if item.name:
                seen.add(item.name)

    check_duplicates(run_groups, "run_group", "run_groups")
    check_duplicates(metric_groups, "metric_group", "metric_groups")
    check_duplicates(metrics, "metric", "metrics")
    check_duplicates(perturbations, "perturbation", "perturbations")

    # Validate run_group.subgroups references.
    for run_group in run_groups:
        if not run_group.name:
            continue

        for i, subgroup_name in enumerate(run_group.subgroups or []):
            if not subgroup_name:
                add_error(
                    f"Empty subgroup reference at index {i}", location=f"run_groups[{run_group.name}].subgroups[{i}]"
                )
            elif subgroup_name not in defined_run_group_names:
                add_error(
                    f"Subgroup '{subgroup_name}' is not defined as a run_group",
                    location=f"run_groups[{run_group.name}].subgroups[{i}]",
                )

    # Validate run_group.metric_groups (and hidden metric group) references.
    for run_group in run_groups:
        if not run_group.name:
            continue

        for i, metric_group_name in enumerate(run_group.metric_groups or []):
            if not metric_group_name:
                add_error(
                    f"Empty metric_group reference at index {i}",
                    location=f"run_groups[{run_group.name}].metric_groups[{i}]",
                )
            elif metric_group_name not in defined_metric_group_names:
                add_error(
                    f"Metric group '{metric_group_name}' is not defined",
                    location=f"run_groups[{run_group.name}].metric_groups[{i}]",
                )

        for i, hidden_mg_name in enumerate(run_group.subgroup_metric_groups_hidden or []):
            if hidden_mg_name and hidden_mg_name not in defined_metric_group_names:
                add_error(
                    f"Hidden metric group '{hidden_mg_name}' is not defined",
                    location=f"run_groups[{run_group.name}].subgroup_metric_groups_hidden[{i}]",
                )

    # Validate metric_group.metrics entries (name, split, perturbation).
    for metric_group in metric_groups:
        if not metric_group.name:
            continue

        for i, metric_matcher in enumerate(metric_group.metrics or []):
            location = f"metric_groups[{metric_group.name}].metrics[{i}]"

            metric_name = metric_matcher.name
            if not metric_name:
                add_error("Empty metric name", location=f"{location}.name")
            elif not is_template_variable(metric_name) and metric_name not in defined_metric_names:
                add_error(
                    f"Metric '{metric_name}' is not defined and is not a template variable", location=f"{location}.name"
                )

            split = metric_matcher.split
            if not split:
                add_error("Empty split value", location=f"{location}.split")
            elif not is_template_variable(split) and split not in VALID_SPLITS:
                add_error(
                    f"Split '{split}' is not valid. Must be one of {sorted(VALID_SPLITS)} "
                    f"or a template variable like ${{main_split}}",
                    location=f"{location}.split",
                )

            # Truthiness alone subsumes the previous `is not None and ...` check.
            pert_name = metric_matcher.perturbation_name
            if pert_name and not is_template_variable(pert_name) and pert_name not in defined_perturbation_names:
                add_error(f"Perturbation '{pert_name}' is not defined", location=f"{location}.perturbation_name")

    # Detect circular references in the subgroup graph.
    for cycle_start, cycle_path in _detect_cycles_in_subgroups(run_groups, schema.name_to_run_group):
        cycle_str = " -> ".join(cycle_path)
        add_error(
            f"Circular reference detected in subgroups: {cycle_str}", location=f"run_groups[{cycle_start}].subgroups"
        )

    # Optional: warn when a run group acts as both parent and leaf.
    if check_parent_child_partition:
        for run_group in run_groups:
            if not run_group.name:
                continue
            if (run_group.subgroups or []) and (run_group.metric_groups or []):
                add_warning(
                    f"Run group '{run_group.name}' has both subgroups and metric_groups",
                    location=f"run_groups[{run_group.name}]",
                )

    # Optional: warn about leaf run groups no parent points at.
    if check_orphan_children:
        referenced_as_subgroup: Set[str] = set()
        for run_group in run_groups:
            referenced_as_subgroup.update(run_group.subgroups or [])

        for run_group in run_groups:
            if not run_group.name:
                continue
            is_child = bool(run_group.metric_groups or []) and not (run_group.subgroups or [])
            if is_child and run_group.name not in referenced_as_subgroup:
                add_warning(
                    f"Child run_group '{run_group.name}' is not referenced by any parent",
                    location=f"run_groups[{run_group.name}]",
                )

    if strict and any(msg.severity == ValidationSeverity.ERROR for msg in messages):
        raise SchemaValidationError(messages)

    return messages
| 594 | + |
| 595 | + |
def validate_schema_file(
    schema_path: str,
    *,
    strict: bool = True,
    **kwargs,
) -> List[SchemaValidationMessage]:
    """Convenience function to validate a schema file directly.

    Any failure while loading the schema (missing file, invalid YAML, or any
    other exception from read_schema) is reported as a single ERROR-level
    message rather than propagating, so callers always get the same result
    shape. Remaining keyword arguments are forwarded to validate_schema.

    If strict=True, raises SchemaValidationError instead of returning a list
    containing error messages.
    """

    def load_failure(message: str) -> List[SchemaValidationMessage]:
        # Shared path for every load error; previously triplicated inline.
        msg = SchemaValidationMessage(
            severity=ValidationSeverity.ERROR,
            message=message,
            schema_path=schema_path,
        )
        if strict:
            raise SchemaValidationError([msg])
        return [msg]

    try:
        schema = read_schema(schema_path)
    except FileNotFoundError:
        return load_failure(f"Schema file not found: {schema_path}")
    except yaml.YAMLError as e:
        return load_failure(f"Invalid YAML syntax: {e}")
    except Exception as e:
        # Catch-all is deliberate: any loader failure becomes a validation message.
        return load_failure(f"Failed to load schema: {type(e).__name__}: {e}")

    return validate_schema(schema, schema_path=schema_path, strict=strict, **kwargs)
| 634 | + |
| 635 | + |
def get_all_schema_paths() -> List[str]:
    """Return sorted paths of every schema_*.yaml file bundled with the HELM package."""
    package_root = resources.files(SCHEMA_YAML_PACKAGE)
    return sorted(
        str(package_root.joinpath(entry.name))
        for entry in package_root.iterdir()
        if entry.name.startswith("schema_") and entry.name.endswith(".yaml")
    )
0 commit comments